diff options
| author | Tim Foley <tfoley@nvidia.com> | 2017-07-17 13:32:20 -0700 |
|---|---|---|
| committer | Tim Foley <tfoley@nvidia.com> | 2017-07-17 13:35:22 -0700 |
| commit | 453a9ca07417bbc17294267c5e44843d16e93c50 (patch) | |
| tree | 3c07ad5576737423cd407772a7d23748eb67f090 /source/slang | |
| parent | 77e3c3bfb1f77ec04cd8e63a676bfa3e2ae2f998 (diff) | |
Handle arrays when scalarizing "resources in structs"
The basic idea is that an array of `struct`s will get scalarized into per-field arrays (for any fields that need to be scalarized). So given:
struct Foo { float x; Texture2D t; };
cbuffer C { Foo foo[4]; }
We'll get output like:
struct Foo { float x; };
cbuffer C { Foo foo[4]; }
Texture2D C_foo_t[4];
(Of course the output would also be translated over to GLSL, but I'm only concerned about this one transformation here).
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/lower.cpp | 240 | ||||
| -rw-r--r-- | source/slang/syntax.h | 4 |
2 files changed, 200 insertions, 44 deletions
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index 98f6d8273..b6a19ab34 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -616,7 +616,7 @@ struct LoweringVisitor result->tupleElements.Add(elem); } - return result; +return result; } RefPtr<ExpressionSyntaxNode> visitVarExpressionSyntaxNode( @@ -708,6 +708,84 @@ struct LoweringVisitor return loweredExpr; } + RefPtr<ExpressionType> getSubscripResultType( + RefPtr<ExpressionType> type) + { + if (auto arrayType = type->As<ArrayExpressionType>()) + { + return arrayType->BaseType; + } + return nullptr; + } + + RefPtr<ExpressionSyntaxNode> createSubscriptExpr( + RefPtr<ExpressionSyntaxNode> baseExpr, + RefPtr<ExpressionSyntaxNode> indexExpr) + { + // TODO: This logic ends up duplicating the `indexExpr` + // that was given, without worrying about any side + // effects it might contain. That needs to be fixed. + + if (auto baseTuple = baseExpr.As<TupleExpr>()) + { + auto loweredExpr = new TupleExpr(); + loweredExpr->Type.type = getSubscripResultType(baseExpr->Type.type); + + if (auto basePrimary = baseTuple->primaryExpr) + { + loweredExpr->primaryExpr = createSubscriptExpr( + basePrimary, + indexExpr); + } + for (auto elem : baseTuple->tupleElements) + { + TupleExpr::Element loweredElem; + loweredElem.tupleFieldDeclRef = elem.tupleFieldDeclRef; + loweredElem.expr = createSubscriptExpr( + elem.expr, + indexExpr); + + loweredExpr->tupleElements.Add(loweredElem); + } + + return loweredExpr; + } + else + { + // Default case: just reconstrut a subscript expr + auto loweredExpr = new IndexExpressionSyntaxNode(); + + loweredExpr->Type.type = getSubscripResultType(baseExpr->Type.type); + + loweredExpr->BaseExpression = baseExpr; + loweredExpr->IndexExpression = indexExpr; + return loweredExpr; + } + } + + RefPtr<ExpressionSyntaxNode> visitIndexExpressionSyntaxNode( + IndexExpressionSyntaxNode* subscriptExpr) + { + auto baseExpr = lowerExpr(subscriptExpr->BaseExpression); + auto indexExpr = lowerExpr(subscriptExpr->IndexExpression); + + // An attempt to subscript a tuple must be turned into a + // tuple of subscript expressions. + if (auto baseTuple = baseExpr.As<TupleExpr>()) + { + return createSubscriptExpr(baseExpr, indexExpr); + } + else + { + // Default case: just reconstrut a subscript expr + RefPtr<IndexExpressionSyntaxNode> loweredExpr = new IndexExpressionSyntaxNode(); + lowerExprCommon(loweredExpr, subscriptExpr); + loweredExpr->BaseExpression = baseExpr; + loweredExpr->IndexExpression = indexExpr; + return loweredExpr; + } + } + void addArgs( InvokeExpressionSyntaxNode* callExpr, RefPtr<ExpressionSyntaxNode> argExpr) @@ -1615,6 +1693,21 @@ struct LoweringVisitor return nullptr; } + ExpressionType* unwrapArray(ExpressionType* inType) + { + auto type = inType; + while (auto arrayType = type->As<ArrayExpressionType>()) + { + type = arrayType->BaseType; + } + return type; + } + + TupleTypeModifier* isTupleTypeOrArrayOfTupleType(ExpressionType* type) + { + return isTupleType(unwrapArray(type)); + } + bool isResourceType(ExpressionType* type) { while (auto arrayType = type->As<ArrayExpressionType>()) @@ -1692,7 +1785,7 @@ struct LoweringVisitor bool isTupleField = false; bool fieldHasAnyNonTupleFields = false; bool fieldHasTupleType = false; - if (auto fieldTupleTypeMod = isTupleType(loweredFieldType)) + if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredFieldType)) { isTupleField = true; fieldHasTupleType = true; @@ -1780,19 +1873,60 @@ struct LoweringVisitor return lowerSimpleVarDeclCommon(loweredDecl, decl, loweredType); } + struct TupleTypeSecondaryVarArraySpec + { + TupleTypeSecondaryVarArraySpec* next; + RefPtr<IntVal> elementCount; + }; + + struct TupleSecondaryVarInfo + { + // Parent tuple decl to add the secondary decl into + RefPtr<TupleVarDecl> tupleDecl; + + // Syntax class for declarations to create + SyntaxClass<VarDeclBase> varDeclClass; + + // Name "stem" to use for any actual variables we create + String name; + + // The parent tuple type (or array thereof) we are scalarizing + RefPtr<ExpressionType> tupleType; + + // The actual declaration of the tuple type (which will give us the fields) + DeclRef<AggTypeDecl> tupleTypeDecl; + + // An initializer expression to use for the tuple members + RefPtr<ExpressionSyntaxNode> initExpr; + + // The original layout given to the top-level variable + RefPtr<VarLayout> primaryVarLayout; + + // The computed layout of the tuple type itself + RefPtr<StructTypeLayout> tupleTypeLayout; + + TupleTypeSecondaryVarArraySpec* arraySpecs = nullptr; + }; + void createTupleTypeSecondaryVarDecls( - RefPtr<TupleVarDecl> tupleDecl, - SyntaxClass<VarDeclBase> varDeclClass, - String const& name, - RefPtr<ExpressionType> tupleType, - DeclRef<AggTypeDecl> tupleTypeDecl, - RefPtr<ExpressionSyntaxNode> initExpr, - RefPtr<VarLayout> primaryVarLayout, - RefPtr<StructTypeLayout> tupleTypeLayout) + TupleSecondaryVarInfo const& info) { + if (auto arrayType = info.tupleType->As<ArrayExpressionType>()) + { + TupleTypeSecondaryVarArraySpec arraySpec; + arraySpec.next = info.arraySpecs; + arraySpec.elementCount = arrayType->ArrayLength; + + TupleSecondaryVarInfo subInfo = info; + subInfo.tupleType = arrayType->BaseType; + subInfo.arraySpecs = &arraySpec; + createTupleTypeSecondaryVarDecls(subInfo); + return; + } + // Next, we need to go through the declarations in the aggregate // type, and deal with all of those that should be tuple-ified. - for (auto dd : getMembersOfType<VarDeclBase>(tupleTypeDecl)) + for (auto dd : getMembersOfType<VarDeclBase>(info.tupleTypeDecl)) { if (dd.getDecl()->HasModifier<HLSLStaticModifier>()) continue; @@ -1802,10 +1936,10 @@ struct LoweringVisitor continue; // TODO: need to extract the initializer for this field - assert(!initExpr); + assert(!info.initExpr); RefPtr<ExpressionSyntaxNode> fieldInitExpr; - String fieldName = name + "_" + dd.GetName(); + String fieldName = info.name + "_" + dd.GetName(); auto fieldType = GetType(dd); @@ -1814,11 +1948,11 @@ struct LoweringVisitor assert(originalFieldDecl); RefPtr<VarLayout> fieldLayout; - if(tupleTypeLayout) + if(info.tupleTypeLayout) { - tupleTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); + info.tupleTypeLayout->mapVarToLayout.TryGetValue(originalFieldDecl, fieldLayout); } - if (fieldLayout && primaryVarLayout) + if (fieldLayout && info.primaryVarLayout) { // The layout for a field may need to be adjusted // based on a base offset stored in the primary @@ -1835,7 +1969,7 @@ struct LoweringVisitor bool needsOffset = false; for (auto rr : fieldLayout->resourceInfos) { - if (auto parentInfo = primaryVarLayout->FindResourceInfo(rr.kind)) + if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(rr.kind)) { if (parentInfo->index != 0 || parentInfo->space != 0) { @@ -1858,7 +1992,7 @@ struct LoweringVisitor auto newResInfo = newFieldLayout->findOrAddResourceInfo(resInfo.kind); newResInfo->index = resInfo.index; newResInfo->space = resInfo.space; - if (auto parentInfo = primaryVarLayout->FindResourceInfo(resInfo.kind)) + if (auto parentInfo = info.primaryVarLayout->FindResourceInfo(resInfo.kind)) { newResInfo->index += parentInfo->index; newResInfo->space += parentInfo->space; @@ -1871,29 +2005,44 @@ struct LoweringVisitor } RefPtr<VarDeclBase> fieldVarOrTupleDecl; - if (auto fieldTupleTypeMod = isTupleType(fieldType)) + if (auto fieldTupleTypeMod = isTupleTypeOrArrayOfTupleType(fieldType)) { // If the field is itself a tuple, then recurse RefPtr<TupleVarDecl> fieldTupleDecl = new TupleVarDecl(); + + TupleSecondaryVarInfo fieldInfo; + fieldInfo.tupleDecl = fieldTupleDecl; + fieldInfo.varDeclClass = info.varDeclClass; + fieldInfo.name = fieldName; + fieldInfo.tupleType = fieldType; + fieldInfo.tupleTypeDecl = makeDeclRef(fieldTupleTypeMod->decl); + fieldInfo.initExpr = fieldInitExpr; + fieldInfo.primaryVarLayout = fieldLayout; + fieldInfo.tupleTypeLayout = getBodyStructTypeLayout(fieldLayout ? fieldLayout->typeLayout : nullptr); + fieldInfo.arraySpecs = info.arraySpecs; + fieldTupleDecl->tupleType = fieldTupleTypeMod; - createTupleTypeSecondaryVarDecls( - fieldTupleDecl, - varDeclClass, - fieldName, - fieldType, - makeDeclRef(fieldTupleTypeMod->decl), - fieldInitExpr, - fieldLayout, - getBodyStructTypeLayout(fieldLayout ? fieldLayout->typeLayout : nullptr)); + createTupleTypeSecondaryVarDecls(fieldInfo); fieldVarOrTupleDecl = fieldTupleDecl; } else { // Otherwise the field has a simple type, and we just need to declare the variable here - RefPtr<VarDeclBase> fieldVarDecl = varDeclClass.createInstance(); + + RefPtr<ExpressionType> fieldVarType = fieldType; + for (auto aa = info.arraySpecs; aa; aa = aa->next) + { + RefPtr<ArrayExpressionType> arrayType = new ArrayExpressionType(); + arrayType->BaseType = fieldVarType; + arrayType->ArrayLength = aa->elementCount; + + fieldVarType = arrayType; + } + + RefPtr<VarDeclBase> fieldVarDecl = info.varDeclClass.createInstance(); fieldVarDecl->Name.Content = fieldName; - fieldVarDecl->Type.type = fieldType; + fieldVarDecl->Type.type = fieldVarType; addDecl(fieldVarDecl); @@ -1911,7 +2060,7 @@ struct LoweringVisitor fieldTupleVarMod->tupleField = tupleFieldMod; addModifier(fieldVarOrTupleDecl, fieldTupleVarMod); - tupleDecl->tupleDecls.Add(fieldVarOrTupleDecl); + info.tupleDecl->tupleDecls.Add(fieldVarOrTupleDecl); } } @@ -1955,15 +2104,17 @@ struct LoweringVisitor addDecl(primaryVarDecl); } - createTupleTypeSecondaryVarDecls( - tupleDecl, - varDeclClass, - name, - tupleType, - tupleTypeDecl, - initExpr, - primaryVarLayout, - tupleTypeLayout); + TupleSecondaryVarInfo info; + info.tupleDecl = tupleDecl; + info.varDeclClass = varDeclClass; + info.name = name; + info.tupleType = tupleType; + info.tupleTypeDecl = tupleTypeDecl; + info.initExpr = initExpr; + info.primaryVarLayout = primaryVarLayout; + info.tupleTypeLayout = tupleTypeLayout; + + createTupleTypeSecondaryVarDecls(info); return tupleDecl; } @@ -1978,6 +2129,11 @@ struct LoweringVisitor typeLayout = parameterBlockTypeLayout->elementTypeLayout; } + while (auto arrayTypeLayout = typeLayout.As<ArrayTypeLayout>()) + { + typeLayout = arrayTypeLayout->elementTypeLayout; + } + if (auto structTypeLayout = typeLayout.As<StructTypeLayout>()) { return structTypeLayout; @@ -2020,7 +2176,7 @@ struct LoweringVisitor { auto loweredType = lowerType(decl->Type); - if (auto tupleTypeMod = isTupleType(loweredType)) + if (auto tupleTypeMod = isTupleTypeOrArrayOfTupleType(loweredType)) { auto varLayout = tryToFindLayout(decl).As<VarLayout>(); @@ -2051,7 +2207,7 @@ struct LoweringVisitor auto varLayout = tryToFindLayout(decl).As<VarLayout>(); auto elementType = bufferType->elementType; - if (auto elementTupleTypeMod = isTupleType(elementType)) + if (auto elementTupleTypeMod = isTupleTypeOrArrayOfTupleType(elementType)) { auto tupleDecl = createTupleTypeVarDecls( loweredDeclClass, diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 83b2f5801..3f1c47fb9 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -244,7 +244,7 @@ namespace Slang : createFunc(createFunc) {} - void* createInstanceImpl() + void* createInstanceImpl() const { return createFunc ? createFunc() : nullptr; } @@ -271,7 +271,7 @@ namespace Slang { } - T* createInstance() + T* createInstance() const { return (T*)createInstanceImpl(); } |
