summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/compiler.h17
-rw-r--r--source/slang/core.meta.slang10
-rw-r--r--source/slang/core.meta.slang.h10
-rw-r--r--source/slang/emit.cpp20
-rw-r--r--source/slang/ir.cpp12
-rw-r--r--source/slang/lower-to-ir.cpp38
-rw-r--r--source/slang/syntax.cpp29
7 files changed, 93 insertions, 43 deletions
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 303be6624..a48ad2287 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -442,9 +442,24 @@ namespace Slang
// Should not be used in front-end code
Type* getIRBasicBlockType();
- // Construct pointer types on-demand
+ // Construct the type `Ptr<valueType>`, where `Ptr`
+ // is looked up as a builtin type.
RefPtr<PtrType> getPtrType(RefPtr<Type> valueType);
+ // Construct the type `Out<valueType>`
+ RefPtr<OutType> getOutType(RefPtr<Type> valueType);
+
+ // Construct the type `InOut<valueType>`
+ RefPtr<InOutType> getInOutType(RefPtr<Type> valueType);
+
+ // Construct a pointer type like `Ptr<valueType>`, but where
+ // the actual type name for the pointer type is given by `ptrTypeName`
+ RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, char const* ptrTypeName);
+
+ // Construct a pointer type like `Ptr<valueType>`, but where
+ // the generic declaration for the pointer type is `genericDecl`
+ RefPtr<PtrTypeBase> getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl);
+
RefPtr<ArrayExpressionType> getArrayType(
Type* elementType,
IntVal* elementCount);
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index f36b53227..85137482d 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -101,6 +101,16 @@ __magic_type(PtrType)
struct Ptr
{};
+__generic<T>
+__magic_type(OutType)
+struct Out
+{};
+
+__generic<T>
+__magic_type(InOutType)
+struct InOut
+{};
+
${{{{
diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h
index 1c0f28b26..8644c7c90 100644
--- a/source/slang/core.meta.slang.h
+++ b/source/slang/core.meta.slang.h
@@ -102,6 +102,16 @@ sb << "__magic_type(PtrType)\n";
sb << "struct Ptr\n";
sb << "{};\n";
sb << "\n";
+sb << "__generic<T>\n";
+sb << "__magic_type(OutType)\n";
+sb << "struct Out\n";
+sb << "{};\n";
+sb << "\n";
+sb << "__generic<T>\n";
+sb << "__magic_type(InOutType)\n";
+sb << "struct InOut\n";
+sb << "{};\n";
+sb << "\n";
sb << "";
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 614e8f474..383442236 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -5946,19 +5946,15 @@ emitDeclImpl(decl, nullptr);
// encoded as a parameter of pointer type, so
// we need to decode that here.
//
- if( auto ptrType = type->As<PtrType>() )
+ if( auto outType = type->As<OutType>() )
{
- // TODO: we need a way to distinguish `out`
- // from `inout`. The easiest way to do
- // that might be to have each be a distinct
- // sub-case of `IRPtrType` - this would also
- // ensure that they can be distinguished from
- // real pointers when the user means to use
- // them.
-
emit("out ");
-
- type = ptrType->getValueType();
+ type = outType->getValueType();
+ }
+ else if( auto inOutType = type->As<InOutType>() )
+ {
+ emit("inout ");
+ type = inOutType->getValueType();
}
emitIRType(ctx, type, name);
@@ -6595,7 +6591,7 @@ emitDeclImpl(decl, nullptr);
{
emitIRUsedType(ctx, genericType->elementType);
}
- else if( auto ptrType = type->As<PtrType>() )
+ else if( auto ptrType = type->As<PtrTypeBase>() )
{
emitIRUsedType(ctx, ptrType->getValueType());
}
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index bfc26643c..2d3127a61 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -956,7 +956,7 @@ namespace Slang
IRInst* IRBuilder::emitLoad(
IRValue* ptr)
{
- auto ptrType = ptr->getType()->As<PtrType>();
+ auto ptrType = ptr->getType()->As<PtrTypeBase>();
if( !ptrType )
{
// Bad!
@@ -2849,12 +2849,10 @@ namespace Slang
builder.curBlock = firstBlock;
builder.insertBeforeInst = firstBlock->getFirstInst();
- // TODO: We need to distinguish any true pointers in the
- // user's code from pointers that only exist for
- // parameter-passing. This `PtrType` here should actually
- // be `OutTypeBase`, but I'm not confident that all
- // the other code is handling that correctly...
- if(auto paramPtrType = paramType->As<PtrType>() )
+ // Is the parameter type a special pointer type
+ // that indicates the parameter is used for `out`
+ // or `inout` access?
+ if(auto paramPtrType = paramType->As<OutTypeBase>() )
{
// Okay, we have the more interesting case here,
// where the parameter was being passed by reference.
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 0f3e85805..b395d2a95 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -819,14 +819,6 @@ IRType* getIntType(
return context->getSession()->getBuiltinType(BaseType::Int);
}
-// Get a pointer type to the given element type
-RefPtr<PtrType> getPtrType(
- IRGenContext* context,
- IRType* valueType)
-{
- return context->getSession()->getPtrType(valueType);
-}
-
RefPtr<IRFuncType> getFuncType(
IRGenContext* context,
UInt paramCount,
@@ -1089,7 +1081,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
RefPtr<Type> loweredBaseType = loweredBaseVal->getType();
if (loweredBaseType->As<PointerLikeType>()
- || loweredBaseType->As<PtrType>())
+ || loweredBaseType->As<PtrTypeBase>())
{
// Note that we do *not* perform an actual `load` operation
// here, but rather just use the pointer value to construct
@@ -1461,7 +1453,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
case LoweredValInfo::Flavor::Ptr:
return LoweredValInfo::ptr(
builder->emitElementAddress(
- getPtrType(context, getSimpleType(type)),
+ context->getSession()->getPtrType(getSimpleType(type)),
baseVal.val,
indexVal));
@@ -1498,7 +1490,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRValue* irBasePtr = base.val;
return LoweredValInfo::ptr(
getBuilder()->emitFieldAddress(
- getPtrType(context, getSimpleType(fieldType)),
+ context->getSession()->getPtrType(getSimpleType(fieldType)),
irBasePtr,
getBuilder()->getDeclRefVal(field)));
}
@@ -3114,14 +3106,22 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
paramTypes.Add(irParamType);
break;
- default:
- // The parameter is being used for input/output purposes,
- // so it will lower to an actual parameter with a pointer type.
- //
- // TODO: Is this the best representation we can use?
+ // If the parameter is declared `out` or `inout`,
+ // then we will represent it with a pointer type in
+ // the IR, but we will use a specialized pointer
+ // type that encodes the parameter direction information.
+ case kParameterDirection_Out:
+ paramTypes.Add(
+ context->getSession()->getOutType(irParamType));
+ break;
+ case kParameterDirection_InOut:
+ paramTypes.Add(
+ context->getSession()->getInOutType(irParamType));
+ break;
- auto irPtrType = getPtrType(context, irParamType);
- paramTypes.Add(irPtrType);
+ default:
+ SLANG_UNEXPECTED("unknown parameter direction");
+ break;
}
}
@@ -3190,7 +3190,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = irParamType.As<PtrType>();
+ auto irPtrType = irParamType.As<PtrTypeBase>();
IRParam* irParamPtr = subBuilder->emitParam(irPtrType);
if(auto paramDecl = paramInfo.decl)
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index fa9c88051..e43dd9074 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -280,19 +280,40 @@ void Type::accept(IValVisitor* visitor, void* extra)
RefPtr<PtrType> Session::getPtrType(
RefPtr<Type> valueType)
{
+ return getPtrType(valueType, "PtrType").As<PtrType>();
+ }
+
+ // Construct the type `Out<valueType>`
+ RefPtr<OutType> Session::getOutType(RefPtr<Type> valueType)
+ {
+ return getPtrType(valueType, "OutType").As<OutType>();
+ }
+
+ RefPtr<InOutType> Session::getInOutType(RefPtr<Type> valueType)
+ {
+ return getPtrType(valueType, "InOutType").As<InOutType>();
+ }
+
+ RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, char const* ptrTypeName)
+ {
auto genericDecl = findMagicDecl(
- this, "PtrType").As<GenericDecl>();
+ this, ptrTypeName).As<GenericDecl>();
+ return getPtrType(valueType, genericDecl);
+ }
+
+ RefPtr<PtrTypeBase> Session::getPtrType(RefPtr<Type> valueType, GenericDecl* genericDecl)
+ {
auto typeDecl = genericDecl->inner;
-
+
auto substitutions = new GenericSubstitution();
- substitutions->genericDecl = genericDecl.Ptr();
+ substitutions->genericDecl = genericDecl;
substitutions->args.Add(valueType);
auto declRef = DeclRef<Decl>(typeDecl.Ptr(), substitutions);
return DeclRefType::Create(
this,
- declRef)->As<PtrType>();
+ declRef)->As<PtrTypeBase>();
}
RefPtr<ArrayExpressionType> Session::getArrayType(