summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-support-types.h2
-rw-r--r--source/slang/slang-check-conversion.cpp62
-rw-r--r--source/slang/slang-lower-to-ir.cpp34
-rw-r--r--source/slang/slang-syntax.h6
-rw-r--r--tests/language-feature/inheritance/derived-struct-init-list.slang42
-rw-r--r--tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt4
6 files changed, 150 insertions, 0 deletions
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index c172ea83c..884668353 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1002,6 +1002,8 @@ namespace Slang
/// True if non empty (equivalent to getCount() != 0 but faster)
bool isNonEmpty() const { return isFilterNonEmpty<T>(m_filterStyle, m_decls.begin(), m_decls.end()); }
+ DeclRef<T> getFirstOrNull() { return isEmpty() ? DeclRef<T>() : (*this)[0]; }
+
DeclRef<T> operator[](Index index) const
{
Decl*const* decl = getFilterCursorByIndex<T>(m_filterStyle, m_decls.begin(), m_decls.end(), index);
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index 8b60b2725..b6c7069a2 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -151,6 +151,44 @@ namespace Slang
ioInitArgIndex);
}
+ DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef)
+ {
+ auto inheritanceDecl = getMembersOfType<InheritanceDecl>(structTypeDeclRef).getFirstOrNull();
+ if(!inheritanceDecl)
+ return nullptr;
+
+ auto baseType = getBaseType(astBuilder, inheritanceDecl);
+ auto baseDeclRefType = as<DeclRefType>(baseType);
+ if(!baseDeclRefType)
+ return nullptr;
+
+ auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseStructDeclRef = baseDeclRef.as<StructDecl>();
+ if(!baseStructDeclRef)
+ return nullptr;
+
+ return baseDeclRefType;
+ }
+
+ DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef)
+ {
+ auto inheritanceDecl = getMembersOfType<InheritanceDecl>(structTypeDeclRef).getFirstOrNull();
+ if (!inheritanceDecl)
+ return DeclRef<StructDecl>();
+
+ auto baseType = getBaseType(astBuilder, inheritanceDecl);
+ auto baseDeclRefType = as<DeclRefType>(baseType);
+ if (!baseDeclRefType)
+ return DeclRef<StructDecl>();
+
+ auto baseDeclRef = baseDeclRefType->declRef;
+ auto baseStructDeclRef = baseDeclRef.as<StructDecl>();
+ if (!baseStructDeclRef)
+ return DeclRef<StructDecl>();
+
+ return baseStructDeclRef;
+ }
+
bool SemanticsVisitor::_readAggregateValueFromInitializerList(
Type* inToType,
Expr** outToExpr,
@@ -375,6 +413,30 @@ namespace Slang
if(auto toStructDeclRef = toTypeDeclRef.as<StructDecl>())
{
// Trying to initialize a `struct` type given an initializer list.
+ //
+ // Before we iterate over the fields, we want to check if this struct
+ // inherits from another `struct` type. If so, we want to read
+ // an initializer for that base type first.
+ //
+ if (auto baseStructType = findBaseStructType(m_astBuilder, toStructDeclRef))
+ {
+ Expr* coercedArg = nullptr;
+ bool argResult = _readValueFromInitializerList(
+ baseStructType,
+ outToExpr ? &coercedArg : nullptr,
+ fromInitializerListExpr,
+ ioArgIndex);
+
+ // No point in trying further if any argument fails
+ if (!argResult)
+ return false;
+
+ if (coercedArg)
+ {
+ coercedArgs.add(coercedArg);
+ }
+ }
+
// We will go through the fields in order and try to match them
// up with initializer arguments.
//
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ab488a41f..d2d15735c 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -2961,6 +2961,16 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>())
{
List<IRInst*> args;
+
+ if (auto structTypeDeclRef = aggTypeDeclRef.as<StructDecl>())
+ {
+ if (auto baseStructType = findBaseStructType(getASTBuilder(), structTypeDeclRef))
+ {
+ auto irBaseVal = getSimpleVal(context, getDefaultVal(baseStructType));
+ args.add(irBaseVal);
+ }
+ }
+
for (auto ff : getMembersOfType<VarDecl>(aggTypeDeclRef, MemberFilterStyle::Instance))
{
auto irFieldVal = getSimpleVal(context, getDefaultVal(ff));
@@ -3082,6 +3092,30 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
if (auto aggTypeDeclRef = declRef.as<AggTypeDecl>())
{
UInt argCounter = 0;
+
+ // If the type is a structure type that inherits from another
+ // structure type, then we need to treat the base type as
+ // an implicit first field.
+ //
+ if(auto structTypeDeclRef = aggTypeDeclRef.as<StructDecl>())
+ {
+ if (auto baseStructType = findBaseStructType(getASTBuilder(), structTypeDeclRef))
+ {
+ UInt argIndex = argCounter++;
+ if (argIndex < argCount)
+ {
+ auto argExpr = expr->args[argIndex];
+ LoweredValInfo argVal = lowerRValueExpr(context, argExpr);
+ args.add(getSimpleVal(context, argVal));
+ }
+ else
+ {
+ auto irDefaultValue = getSimpleVal(context, getDefaultVal(baseStructType));
+ args.add(irDefaultValue);
+ }
+ }
+ }
+
for (auto ff : getMembersOfType<VarDecl>(aggTypeDeclRef, MemberFilterStyle::Instance))
{
UInt argIndex = argCounter++;
diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h
index a23ad224e..c144ceb70 100644
--- a/source/slang/slang-syntax.h
+++ b/source/slang/slang-syntax.h
@@ -136,6 +136,12 @@ namespace Slang
return getMembersOfType<VarDecl>(declRef, filterStyle);
}
+ /// If the given `structTypeDeclRef` inherits from another struct type, return that base type
+ DeclRefType* findBaseStructType(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef);
+
+ /// If the given `structTypeDeclRef` inherits from another struct type, return that base struct decl
+ DeclRef<StructDecl> findBaseStructDeclRef(ASTBuilder* astBuilder, DeclRef<StructDecl> const& structTypeDeclRef);
+
inline Type* getTagType(ASTBuilder* astBuilder, DeclRef<EnumDecl> const& declRef)
{
return declRef.substitute(astBuilder, declRef.getDecl()->tagType);
diff --git a/tests/language-feature/inheritance/derived-struct-init-list.slang b/tests/language-feature/inheritance/derived-struct-init-list.slang
new file mode 100644
index 000000000..edcb685e6
--- /dev/null
+++ b/tests/language-feature/inheritance/derived-struct-init-list.slang
@@ -0,0 +1,42 @@
+// derived-struct-init-list.slang
+
+//TEST(compute):COMPARE_COMPUTE:
+
+// Test that use of an initializer list (especially
+// an empty initializer list) is still possible
+// when using `struct` inheritance.
+
+struct Base
+{
+ int a = 1;
+}
+
+struct Derived : Base
+{
+ int b = 2;
+
+ void write(inout int val) { val = val*0x100 + a*0x10 + b; }
+}
+
+int test(int val)
+{
+ Derived x = {};
+ Derived y = { val, val+1 };
+
+ int result = 1;
+ x.write(result);
+ y.write(result);
+ return result;
+}
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ int inVal = tid;
+ int outVal = test(inVal);
+ outputBuffer[tid] = outVal;
+}
diff --git a/tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt b/tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt
new file mode 100644
index 000000000..373286856
--- /dev/null
+++ b/tests/language-feature/inheritance/derived-struct-init-list.slang.expected.txt
@@ -0,0 +1,4 @@
+11201
+11212
+11223
+11234