diff options
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-conversion.cpp | 62 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 34 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 6 | ||||
| -rw-r--r-- | tests/language-feature/inheritance/derived-struct-init-list.slang | 42 | ||||
| -rw-r--r-- | tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt | 4 |
6 files changed, 150 insertions, 0 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index c172ea83c..884668353 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1002,6 +1002,8 @@ namespace Slang /// True if non empty (equivalent to getCount() != 0 but faster) bool isNonEmpty() const { return isFilterNonEmpty<T>(m_filterStyle, m_decls.begin(), m_decls.end()); } + DeclRef<T> getFirstOrNull() { return isEmpty() ? DeclRef<T>() : (*this)[0]; } + DeclRef<T> operator[](Index index) const { Decl*const* decl = getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index); diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp index 8b60b2725..b6c7069a2 100644 --- a/source/slang/slang-check-conversion.cpp +++ b/source/slang/slang-check-conversion.cpp @@ -151,6 +151,44 @@ namespace Slang ioInitArgIndex); } + DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef) + { + auto inheritanceDecl = getMembersOfType<InheritanceDecl>(structTypeDeclRef).getFirstOrNull(); + if(!inheritanceDecl) + return nullptr; + + auto baseType = getBaseType(astBuilder, inheritanceDecl); + auto baseDeclRefType = as<DeclRefType>(baseType); + if(!baseDeclRefType) + return nullptr; + + auto baseDeclRef = baseDeclRefType->declRef; + auto baseStructDeclRef = baseDeclRef.as<StructDecl>(); + if(!baseStructDeclRef) + return nullptr; + + return baseDeclRefType; + } + + DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef) + { + auto inheritanceDecl = getMembersOfType<InheritanceDecl>(structTypeDeclRef).getFirstOrNull(); + if (!inheritanceDecl) + return DeclRef<StructDecl>(); + + auto baseType = getBaseType(astBuilder, inheritanceDecl); + auto baseDeclRefType = as<DeclRefType>(baseType); + if (!baseDeclRefType) + return DeclRef<StructDecl>(); + + auto baseDeclRef = baseDeclRefType->declRef; + auto baseStructDeclRef = baseDeclRef.as<StructDecl>(); + if (!baseStructDeclRef) + return DeclRef<StructDecl>(); + + return baseStructDeclRef; + } + bool SemanticsVisitor::_readAggregateValueFromInitializerList( Type* inToType, Expr** outToExpr, @@ -375,6 +413,30 @@ namespace Slang if(auto toStructDeclRef = toTypeDeclRef.as<StructDecl>()) { // Trying to initialize a `struct` type given an initializer list. + // + // Before we iterate over the fields, we want to check if this struct + // inherits from another `struct` type. If so, we want to read + // an initializer for that base type first. + // + if (auto baseStructType = findBaseStructType(m_astBuilder, toStructDeclRef)) + { + Expr* coercedArg = nullptr; + bool argResult = _readValueFromInitializerList( + baseStructType, + outToExpr ? &coercedArg : nullptr, + fromInitializerListExpr, + ioArgIndex); + + // No point in trying further if any argument fails + if (!argResult) + return false; + + if (coercedArg) + { + coercedArgs.add(coercedArg); + } + } + // We will go through the fields in order and try to match them // up with initializer arguments. // diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index ab488a41f..d2d15735c 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2961,6 +2961,16 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { List<IRInst*> args; + + if (auto structTypeDeclRef = aggTypeDeclRef.as<StructDecl>()) + { + if (auto baseStructType = findBaseStructType(getASTBuilder(), structTypeDeclRef)) + { + auto irBaseVal = getSimpleVal(context, getDefaultVal(baseStructType)); + args.add(irBaseVal); + } + } + for (auto ff : getMembersOfType<VarDecl>(aggTypeDeclRef, MemberFilterStyle::Instance)) { auto irFieldVal = getSimpleVal(context, getDefaultVal(ff)); @@ -3082,6 +3092,30 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>()) { UInt argCounter = 0; + + // If the type is a structure type that inherits from another + // structure type, then we need to treat the base type as + // an implicit first field. + // + if(auto structTypeDeclRef = aggTypeDeclRef.as<StructDecl>()) + { + if (auto baseStructType = findBaseStructType(getASTBuilder(), structTypeDeclRef)) + { + UInt argIndex = argCounter++; + if (argIndex < argCount) + { + auto argExpr = expr->args[argIndex]; + LoweredValInfo argVal = lowerRValueExpr(context, argExpr); + args.add(getSimpleVal(context, argVal)); + } + else + { + auto irDefaultValue = getSimpleVal(context, getDefaultVal(baseStructType)); + args.add(irDefaultValue); + } + } + } + for (auto ff : getMembersOfType<VarDecl>(aggTypeDeclRef, MemberFilterStyle::Instance)) { UInt argIndex = argCounter++; diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index a23ad224e..c144ceb70 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -136,6 +136,12 @@ namespace Slang return getMembersOfType<VarDecl>(declRef, filterStyle); } + /// If the given `structTypeDeclRef` inherits from another struct type, return that base type + DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef); + + /// If the given `structTypeDeclRef` inherits from another struct type, return that base struct decl + DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef); + inline Type* getTagType(ASTBuilder* astBuilder, DeclRef<EnumDecl> const& declRef) { return declRef.substitute(astBuilder, declRef.getDecl()->tagType); diff --git a/tests/language-feature/inheritance/derived-struct-init-list.slang b/tests/language-feature/inheritance/derived-struct-init-list.slang new file mode 100644 index 000000000..edcb685e6 --- /dev/null +++ b/tests/language-feature/inheritance/derived-struct-init-list.slang @@ -0,0 +1,42 @@ +// derived-struct-init-list.slang + +//TEST(compute):COMPARE_COMPUTE: + +// Test that use of an initializer list (especially +// an empty initializer list) is still possible +// when using `struct` inheritance. + +struct Base +{ + int a = 1; +} + +struct Derived : Base +{ + int b = 2; + + void write(inout int val) { val = val*0x100 + a*0x10 + b; } +} + +int test(int val) +{ + Derived x = {}; + Derived y = { val, val+1 }; + + int result = 1; + x.write(result); + y.write(result); + return result; +} + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = tid; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +} diff --git a/tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt b/tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt new file mode 100644 index 000000000..373286856 --- /dev/null +++ b/tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt @@ -0,0 +1,4 @@ +11201 +11212 +11223 +11234 |
