summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorTim Foley <tfoley@nvidia.com>2017-07-17 13:32:20 -0700
committerTim Foley <tfoley@nvidia.com>2017-07-17 13:35:22 -0700
commit453a9ca07417bbc17294267c5e44843d16e93c50 (patch)
tree3c07ad5576737423cd407772a7d23748eb67f090 /source/slang
parent77e3c3bfb1f77ec04cd8e63a676bfa3e2ae2f998 (diff)
Handle arrays when scalarizing "resources in structs"
The basic idea is that an array of `struct`s will get scalarized into per-field arrays (for any fields that need to be scalarized). So given: struct Foo { float x; Texture2D t; }; cbuffer C { Foo foo[4]; } We'll get output like: struct Foo { float x; }; cbuffer C { Foo foo[4]; } Texture2D C_foo_t[4]; (Of course the output would also be translated over to GLSL, but I'm only concerned about this one transformation here).
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/lower.cpp240
-rw-r--r--source/slang/syntax.h4
2 files changed, 200 insertions, 44 deletions
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index 98f6d8273..b6a19ab34 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -616,7 +616,7 @@ struct LoweringVisitor
result->tupleElements.Add(elem);
}
- return result;
+return result;
}
RefPtr<ExpressionSyntaxNode> visitVarExpressionSyntaxNode(
@@ -708,6 +708,84 @@ struct LoweringVisitor
return loweredExpr;
}
+ RefPtr<ExpressionType> getSubscripResultType(
+ RefPtr<ExpressionType> type)
+ {
+ if (auto arrayType = type->As<ArrayExpressionType>())
+ {
+ return arrayType->BaseType;
+ }
+ return nullptr;
+ }
+
+ RefPtr<ExpressionSyntaxNode> createSubscriptExpr(
+ RefPtr<ExpressionSyntaxNode> baseExpr,
+ RefPtr<ExpressionSyntaxNode> indexExpr)
+ {
+ // TODO: This logic ends up duplicating the `indexExpr`
+ // that was given, without worrying about any side
+ // effects it might contain. That needs to be fixed.
+
+ if (auto baseTuple = baseExpr.As<TupleExpr>())
+ {
+ auto loweredExpr = new TupleExpr();
+ loweredExpr->Type.type = getSubscripResultType(baseExpr->Type.type);
+
+ if (auto basePrimary = baseTuple->primaryExpr)
+ {
+ loweredExpr->primaryExpr = createSubscriptExpr(
+ basePrimary,
+ indexExpr);
+ }
+ for (auto elem : baseTuple->tupleElements)
+ {
+ TupleExpr::Element loweredElem;
+ loweredElem.tupleFieldDeclRef = elem.tupleFieldDeclRef;
+ loweredElem.expr = createSubscriptExpr(
+ elem.expr,
+ indexExpr);
+
+ loweredExpr->tupleElements.Add(loweredElem);
+ }
+
+ return loweredExpr;
+ }
+ else
+ {
+ // Default case: just reconstrut a subscript expr
+ auto loweredExpr = new IndexExpressionSyntaxNode();
+
+ loweredExpr->Type.type = getSubscripResultType(baseExpr->Type.type);
+
+ loweredExpr->BaseExpression = baseExpr;
+ loweredExpr->IndexExpression = indexExpr;
+ return loweredExpr;
+ }
+ }
+
+ RefPtr<ExpressionSyntaxNode> visitIndexExpressionSyntaxNode(
+ IndexExpressionSyntaxNode* subscriptExpr)
+ {
+ auto baseExpr = lowerExpr(subscriptExpr->BaseExpression);
+ auto indexExpr = lowerExpr(subscriptExpr->IndexExpression);
+
+ // An attempt to subscript a tuple must be turned into a
+ // tuple of subscript expressions.
+ if (auto baseTuple = baseExpr.As<TupleExpr>())
+ {
+ return createSubscriptExpr(baseExpr, indexExpr);
+ }
+ else
+ {
+ // Default case: just reconstrut a subscript expr
+ RefPtr<IndexExpressionSyntaxNode> loweredExpr = new IndexExpressionSyntaxNode();
+ lowerExprCommon(loweredExpr, subscriptExpr);
+ loweredExpr->BaseExpression = baseExpr;
+ loweredExpr->IndexExpression = indexExpr;
+ return loweredExpr;
+ }
+ }
+
void addArgs(
InvokeExpressionSyntaxNode* callExpr,
RefPtr<ExpressionSyntaxNode> argExpr)
@@ -1615,6 +1693,21 @@ struct LoweringVisitor
return nullptr;
}
+ ExpressionType* unwrapArray(ExpressionType* inType)
+ {
+ auto type = inType;
+ while (auto arrayType = type->As<ArrayExpressionType>())
+ {
+ type = arrayType->BaseType;
+ }
+ return type;
+ }
+
+ TupleTypeModifier* isTupleTypeOrArrayOfTupleType(ExpressionType* type)
+ {
+ return isTupleType(unwrapArray(type));
+ }
+
bool isResourceType(ExpressionType* type)
{
while (auto arrayType = type->As<ArrayExpressionType>())
@@ -1692,7 +1785,7 @@ struct LoweringVisitor
bool isTupleField = false;
bool fieldHasAnyNonTupleFields = false;
bool fieldHasTupleType = false;
- if (auto fieldTupleTypeMod = isTupleType(loweredFieldType))
+ if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredFieldType))
{
isTupleField = true;
fieldHasTupleType = true;
@@ -1780,19 +1873,60 @@ struct LoweringVisitor
return lowerSimpleVarDeclCommon(loweredDecl, decl, loweredType);
}
+ struct TupleTypeSecondaryVarArraySpec
+ {
+ TupleTypeSecondaryVarArraySpec* next;
+ RefPtr<IntVal> elementCount;
+ };
+
+ struct TupleSecondaryVarInfo
+ {
+ // Parent tuple decl to add the secondary decl into
+ RefPtr<TupleVarDecl> tupleDecl;
+
+ // Syntax class for declarations to create
+ SyntaxClass<VarDeclBase> varDeclClass;
+
+ // Name "stem" to use for any actual variables we create
+ String name;
+
+ // The parent tuple type (or array thereof) we are scalarizing
+ RefPtr<ExpressionType> tupleType;
+
+ // The actual declaration of the tuple type (which will give us the fields)
+ DeclRef<AggTypeDecl> tupleTypeDecl;
+
+ // An initializer expression to use for the tuple members
+ RefPtr<ExpressionSyntaxNode> initExpr;
+
+ // The original layout given to the top-level variable
+ RefPtr<VarLayout> primaryVarLayout;
+
+ // The computed layout of the tuple type itself
+ RefPtr<StructTypeLayout> tupleTypeLayout;
+
+ TupleTypeSecondaryVarArraySpec* arraySpecs = nullptr;
+ };
+
void createTupleTypeSecondaryVarDecls(
- RefPtr<TupleVarDecl> tupleDecl,
- SyntaxClass<VarDeclBase> varDeclClass,
- String const& name,
- RefPtr<ExpressionType> tupleType,
- DeclRef<AggTypeDecl> tupleTypeDecl,
- RefPtr<ExpressionSyntaxNode> initExpr,
- RefPtr<VarLayout> primaryVarLayout,
- RefPtr<StructTypeLayout> tupleTypeLayout)
+ TupleSecondaryVarInfo const& info)
{
+ if (auto arrayType = info.tupleType->As<ArrayExpressionType>())
+ {
+ TupleTypeSecondaryVarArraySpec arraySpec;
+ arraySpec.next = info.arraySpecs;
+ arraySpec.elementCount = arrayType->ArrayLength;
+
+ TupleSecondaryVarInfo subInfo = info;
+ subInfo.tupleType = arrayType->BaseType;
+ subInfo.arraySpecs = &arraySpec;
+ createTupleTypeSecondaryVarDecls(subInfo);
+ return;
+ }
+
// Next, we need to go through the declarations in the aggregate
// type, and deal with all of those that should be tuple-ified.
- for (auto dd : getMembersOfType<VarDeclBase>(tupleTypeDecl))
+ for (auto dd : getMembersOfType<VarDeclBase>(info.tupleTypeDecl))
{
if (dd.getDecl()->HasModifier<HLSLStaticModifier>())
continue;
@@ -1802,10 +1936,10 @@ struct LoweringVisitor
continue;
// TODO: need to extract the initializer for this field
- assert(!initExpr);
+ assert(!info.initExpr);
RefPtr<ExpressionSyntaxNode> fieldInitExpr;
- String fieldName = name + "_" + dd.GetName();
+ String fieldName = info.name + "_" + dd.GetName();
auto fieldType = GetType(dd);
@@ -1814,11 +1948,11 @@ struct LoweringVisitor
assert(originalFieldDecl);
RefPtr<VarLayout> fieldLayout;
- if(tupleTypeLayout)
+ if(info.tupleTypeLayout)
{
- tupleTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout);
+ info.tupleTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout);
}
- if (fieldLayout && primaryVarLayout)
+ if (fieldLayout && info.primaryVarLayout)
{
// The layout for a field may need to be adjusted
// based on a base offset stored in the primary
@@ -1835,7 +1969,7 @@ struct LoweringVisitor
bool needsOffset = false;
for (auto rr : fieldLayout->resourceInfos)
{
- if (auto parentInfo = primaryVarLayout->FindResourceInfo(rr.kind))
+ if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(rr.kind))
{
if (parentInfo->index != 0 || parentInfo->space != 0)
{
@@ -1858,7 +1992,7 @@ struct LoweringVisitor
auto newResInfo = newFieldLayout->findOrAddResourceInfo(resInfo.kind);
newResInfo->index = resInfo.index;
newResInfo->space = resInfo.space;
- if (auto parentInfo = primaryVarLayout->FindResourceInfo(resInfo.kind))
+ if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(resInfo.kind))
{
newResInfo->index += parentInfo->index;
newResInfo->space += parentInfo->space;
@@ -1871,29 +2005,44 @@ struct LoweringVisitor
}
RefPtr<VarDeclBase> fieldVarOrTupleDecl;
- if (auto fieldTupleTypeMod = isTupleType(fieldType))
+ if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(fieldType))
{
// If the field is itself a tuple, then recurse
RefPtr<TupleVarDecl> fieldTupleDecl = new TupleVarDecl();
+
+ TupleSecondaryVarInfo fieldInfo;
+ fieldInfo.tupleDecl = fieldTupleDecl;
+ fieldInfo.varDeclClass = info.varDeclClass;
+ fieldInfo.name = fieldName;
+ fieldInfo.tupleType = fieldType;
+ fieldInfo.tupleTypeDecl = makeDeclRef(fieldTupleTypeMod->decl);
+ fieldInfo.initExpr = fieldInitExpr;
+ fieldInfo.primaryVarLayout = fieldLayout;
+ fieldInfo.tupleTypeLayout = getBodyStructTypeLayout(fieldLayout ? fieldLayout->typeLayout : nullptr);
+ fieldInfo.arraySpecs = info.arraySpecs;
+
fieldTupleDecl->tupleType = fieldTupleTypeMod;
- createTupleTypeSecondaryVarDecls(
- fieldTupleDecl,
- varDeclClass,
- fieldName,
- fieldType,
- makeDeclRef(fieldTupleTypeMod->decl),
- fieldInitExpr,
- fieldLayout,
- getBodyStructTypeLayout(fieldLayout ? fieldLayout->typeLayout : nullptr));
+ createTupleTypeSecondaryVarDecls(fieldInfo);
fieldVarOrTupleDecl = fieldTupleDecl;
}
else
{
// Otherwise the field has a simple type, and we just need to declare the variable here
- RefPtr<VarDeclBase> fieldVarDecl = varDeclClass.createInstance();
+
+ RefPtr<ExpressionType> fieldVarType = fieldType;
+ for (auto aa = info.arraySpecs; aa; aa = aa->next)
+ {
+ RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType();
+ arrayType->BaseType = fieldVarType;
+ arrayType->ArrayLength = aa->elementCount;
+
+ fieldVarType = arrayType;
+ }
+
+ RefPtr<VarDeclBase> fieldVarDecl = info.varDeclClass.createInstance();
fieldVarDecl->Name.Content = fieldName;
- fieldVarDecl->Type.type = fieldType;
+ fieldVarDecl->Type.type = fieldVarType;
addDecl(fieldVarDecl);
@@ -1911,7 +2060,7 @@ struct LoweringVisitor
fieldTupleVarMod->tupleField = tupleFieldMod;
addModifier(fieldVarOrTupleDecl, fieldTupleVarMod);
- tupleDecl->tupleDecls.Add(fieldVarOrTupleDecl);
+ info.tupleDecl->tupleDecls.Add(fieldVarOrTupleDecl);
}
}
@@ -1955,15 +2104,17 @@ struct LoweringVisitor
addDecl(primaryVarDecl);
}
- createTupleTypeSecondaryVarDecls(
- tupleDecl,
- varDeclClass,
- name,
- tupleType,
- tupleTypeDecl,
- initExpr,
- primaryVarLayout,
- tupleTypeLayout);
+ TupleSecondaryVarInfo info;
+ info.tupleDecl = tupleDecl;
+ info.varDeclClass = varDeclClass;
+ info.name = name;
+ info.tupleType = tupleType;
+ info.tupleTypeDecl = tupleTypeDecl;
+ info.initExpr = initExpr;
+ info.primaryVarLayout = primaryVarLayout;
+ info.tupleTypeLayout = tupleTypeLayout;
+
+ createTupleTypeSecondaryVarDecls(info);
return tupleDecl;
}
@@ -1978,6 +2129,11 @@ struct LoweringVisitor
typeLayout = parameterBlockTypeLayout->elementTypeLayout;
}
+ while (auto arrayTypeLayout = typeLayout.As<ArrayTypeLayout>())
+ {
+ typeLayout = arrayTypeLayout->elementTypeLayout;
+ }
+
if (auto structTypeLayout = typeLayout.As<StructTypeLayout>())
{
return structTypeLayout;
@@ -2020,7 +2176,7 @@ struct LoweringVisitor
{
auto loweredType = lowerType(decl->Type);
- if (auto tupleTypeMod = isTupleType(loweredType))
+ if (auto tupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredType))
{
auto varLayout = tryToFindLayout(decl).As<VarLayout>();
@@ -2051,7 +2207,7 @@ struct LoweringVisitor
auto varLayout = tryToFindLayout(decl).As<VarLayout>();
auto elementType = bufferType->elementType;
- if (auto elementTupleTypeMod = isTupleType(elementType))
+ if (auto elementTupleTypeMod = isTupleTypeOrArrayOfTupleType(elementType))
{
auto tupleDecl = createTupleTypeVarDecls(
loweredDeclClass,
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 83b2f5801..3f1c47fb9 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -244,7 +244,7 @@ namespace Slang
: createFunc(createFunc)
{}
- void* createInstanceImpl()
+ void* createInstanceImpl() const
{
return createFunc ? createFunc() : nullptr;
}
@@ -271,7 +271,7 @@ namespace Slang
{
}
- T* createInstance()
+ T* createInstance() const
{
return (T*)createInstanceImpl();
}