diff options
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 29 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 39 | ||||
| -rw-r--r-- | tests/language-feature/extensions/generic-extension-3.slang | 42 | ||||
| -rw-r--r-- | tests/language-feature/extensions/generic-extension-4.slang | 58 |
4 files changed, 145 insertions, 23 deletions
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 45efca2d9..2271e750e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8242,6 +8242,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { subType = DeclRefType::create(context->astBuilder, makeDeclRef(parentDecl)); } + bool isGenericExtension = false; + // Test if we are in a generic extension context + if (parentDecl->parentDecl) + { + auto genDecl = as<GenericDecl>(parentDecl->parentDecl); + if (genDecl) + { + isGenericExtension = true; + } + } // What is the super-type that we have declared we inherit from? Type* superType = inheritanceDecl->base.type; @@ -8304,14 +8314,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { // Construct the mangled name for the witness table, which depends // on the type that is conforming, and the type that it conforms to. - // - // TODO: This approach doesn't really make sense for generic `extension` - // conformances. - auto mangledName = getMangledNameForConformanceWitness( - context->astBuilder, - subType, - superType, - irSubType->getOp()); + String mangledName; + if (isGenericExtension) + { + mangledName = + getMangledNameForConformanceWitness(context->astBuilder, parentDecl, superType); + } + else + { + mangledName = + getMangledNameForConformanceWitness(context->astBuilder, subType, superType); + } // TODO(JS): // Should the mangled name take part in obfuscation if enabled? diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index aa30eef9d..f08ffd75d 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -391,6 +391,7 @@ void emitVal(ManglingContext* context, Val* val) void emitQualifiedName(ManglingContext* context, DeclRef<Decl> declRef, bool includeModuleName) { + bool ignoreName = false; if (!includeModuleName) { if (as<ModuleDecl>(declRef)) @@ -445,19 +446,6 @@ void emitQualifiedName(ManglingContext* context, DeclRef<Decl> declRef, bool inc return; } - // Inheritance declarations don't have meaningful names, - // and so we should emit them based on the type - // that is doing the inheriting. - if (auto inheritanceDeclRef = declRef.as<TypeConstraintDecl>()) - { - emit(context, "I"); - emitType(context, getSup(context->astBuilder, inheritanceDeclRef)); - return; - } - - // Similarly, an extension doesn't have a name worth - // emitting, and we should base things on its target - // type instead. if (auto extensionDeclRef = declRef.as<ExtensionDecl>()) { // TODO: as a special case, an "unconditional" extension @@ -471,7 +459,25 @@ void emitQualifiedName(ManglingContext* context, DeclRef<Decl> declRef, bool inc emit(context, "I"); emitType(context, getSup(context->astBuilder, inheritanceDecl)); } - return; + // A non generic extension doesn't have a name worth + // emitting, and we should base things on its target + // type instead. + if (parentGenericDeclRef) + { + ignoreName = true; + } + } + // Inheritance declarations don't have meaningful names, + // and so we should emit them based on the type + // that is doing the inheriting. + else if (auto inheritanceDeclRef = declRef.as<TypeConstraintDecl>()) + { + emit(context, "I"); + emitType(context, getSup(context->astBuilder, inheritanceDeclRef)); + if (parentGenericDeclRef) + { + ignoreName = true; + } } // TODO: we should special case GenericTypeParamDecl and GenericValueParamDecl nodes @@ -480,7 +486,10 @@ void emitQualifiedName(ManglingContext* context, DeclRef<Decl> declRef, bool inc // For each generic parameter, we should assign it a unique ID (i, j), where i is the // nesting level of the generic, and j is the sequential order of the parameter within // its generic parent, and use this 2D ID to refer to such a parameter. - emitName(context, declRef.getName()); + if (!ignoreName) + { + emitName(context, declRef.getName()); + } // Special case: accessors need some way to distinguish themselves // so that a getter/setter/ref-er don't all compile to the same name. diff --git a/tests/language-feature/extensions/generic-extension-3.slang b/tests/language-feature/extensions/generic-extension-3.slang new file mode 100644 index 000000000..358c3facd --- /dev/null +++ b/tests/language-feature/extensions/generic-extension-3.slang @@ -0,0 +1,42 @@ +// Test that multiple functions from different extensions are properly linked. + +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +interface IValuable +{ + int getValue(); +}; + +__generic <Int : __BuiltinIntegerType> +extension Int : IValuable +{ + int getValue() + { + return 0; + } +}; + +__generic <Float : __BuiltinFloatingPointType> +extension Float : IValuable +{ + int getValue() + { + return 1; + } +}; + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + uint i = 0; + float f = float(i) / float(10); + + // CHECK: 0 + outputBuffer[0] = i.getValue(); + + // CHECK: 1 + outputBuffer[1] = f.getValue(); +}
\ No newline at end of file diff --git a/tests/language-feature/extensions/generic-extension-4.slang b/tests/language-feature/extensions/generic-extension-4.slang new file mode 100644 index 000000000..61e1c1386 --- /dev/null +++ b/tests/language-feature/extensions/generic-extension-4.slang @@ -0,0 +1,58 @@ +// Test that multiple functions from different extensions are properly linked through their respective witness table. + +//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj +interface IValuableImpl +{ + int getValue(); +} + +interface IValuable +{ + associatedtype Impl : IValuableImpl; + Impl getImpl(); +}; + +__generic <Int : __BuiltinIntegerType> +extension Int : IValuableImpl +{ + int getValue() + { + return 0; + } +}; + +__generic <Float : __BuiltinFloatingPointType> +extension Float : IValuableImpl +{ + int getValue() + { + return 1; + } +}; + +__generic <ValuableImpl : IValuableImpl> +extension ValuableImpl : IValuable +{ + typealias Impl = ValuableImpl; + ValuableImpl getImpl() + { + return this; + } +}; + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; + +[shader("compute")] +[numthreads(1, 1, 1)] +void computeMain() +{ + uint i = 0; + float f = float(i) / float(10); + + // CHECK: 0 + outputBuffer[0] = i.getImpl().getValue(); + + // CHECK: 1 + outputBuffer[1] = f.getImpl().getValue(); +}
\ No newline at end of file |
