diff options
| author | Yong He <yonghe@outlook.com> | 2020-08-21 01:10:45 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2020-08-21 01:10:45 -0700 |
| commit | a8bc5983eae60ca37e853041e989a654c1247876 (patch) | |
| tree | 330cae10c2c24ba14ca726b61c576d9f362f5b8e | |
| parent | 11748a75e66c2bd3fa7ef7635fd35363465f599c (diff) | |
Allow calling a generic function with an existential value (dynamic dispatch) (#1508)
* Allow calling a generic function with an existential value (dynamic dispatch).
* Fixes per review comments.
* Clean up implementation by having `openExistential` return `ExtractExistentialType` instead of a DeclRef to the interface with a `ThisTypeSubstitution`.
* More cleanups
Co-authored-by: Tim Foley <tfoleyNV@users.noreply.github.com>
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-conformance.cpp | 24 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-lookup.cpp | 4 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-10.slang | 47 | ||||
| -rw-r--r-- | tests/compute/dynamic-dispatch-10.slang.expected.txt | 4 |
8 files changed, 91 insertions, 4 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index f5feb4d64..55bb9a1e5 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -676,7 +676,7 @@ bool ExtractExistentialType::_equalsImplOverride(Type* type) HashCode ExtractExistentialType::_getHashCodeOverride() { - return declRef.getHashCode(); + return combineHash(declRef.getHashCode(), interfaceDeclRef.getHashCode()); } Type* ExtractExistentialType::_createCanonicalTypeOverride() @@ -688,13 +688,15 @@ Val* ExtractExistentialType::_substituteImplOverride(ASTBuilder* astBuilder, Sub { int diff = 0; auto substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); + auto interfaceSubstDeclRef = interfaceDeclRef.substituteImpl(astBuilder, subst, &diff); if (!diff) return this; (*ioDiff)++; ExtractExistentialType* substValue = astBuilder->create<ExtractExistentialType>(); - substValue->declRef = declRef; + substValue->declRef = substDeclRef; + substValue->interfaceDeclRef = interfaceSubstDeclRef; return substValue; } diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 6673fa426..4c1c07ba3 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -611,6 +611,7 @@ class ExtractExistentialType : public Type SLANG_CLASS(ExtractExistentialType) DeclRef<VarDeclBase> declRef; + DeclRef<InterfaceDecl> interfaceDeclRef; // Overrides should be public so base classes can access String _toStringOverride(); diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 4dcfb3065..e1904dc71 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -268,6 +268,30 @@ namespace Slang } } } + else if (auto extractExistentialType = as<ExtractExistentialType>(subType)) + { + // An ExtractExistentialType from an existential value of type I + // is a subtype of I. + // We need to check and make sure the interface type of the `ExtractExistentialType` + // is equal to `superType`. + auto interfaceDeclRef = extractExistentialType->interfaceDeclRef; + auto thisTypeSubst = findThisTypeSubstitution(interfaceDeclRef.substitutions.substitutions, interfaceDeclRef.getDecl()); + SLANG_ASSERT(thisTypeSubst && thisTypeSubst == interfaceDeclRef.substitutions.substitutions); + // The interfaceDeclRef in `extractExistentialType` contains a `ThisTypeSubstitution` + // to allow member lookup to return correct substituted types. Here we just need + // to know if that interface is the same as the superType, so we need to exclude + // the `ThisTypeSubstitution` from comparison. + interfaceDeclRef.substitutions.substitutions = thisTypeSubst->outer; + if (interfaceDeclRef.equals(superTypeDeclRef)) + { + if (outWitness) + { + *outWitness = thisTypeSubst->witness; + } + return true; + } + return false; + } else if(auto taggedUnionType = as<TaggedUnionType>(subType)) { // A tagged union type conforms to an interface if all of diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 63367fa97..0cc3a55c5 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -116,11 +116,11 @@ namespace Slang openedThisType->witness = openedWitness; DeclRef<InterfaceDecl> substDeclRef = DeclRef<InterfaceDecl>(interfaceDecl, openedThisType); - auto substDeclRefType = DeclRefType::create(m_astBuilder, substDeclRef); + openedType->interfaceDeclRef = substDeclRef; ExtractExistentialValueExpr* openedValue = m_astBuilder->create<ExtractExistentialValueExpr>(); openedValue->declRef = varDeclRef; - openedValue->type = QualType(substDeclRefType); + openedValue->type = QualType(openedType); return openedValue; }); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 94638e473..df5424659 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1517,6 +1517,11 @@ namespace Slang return CreateErrorExpr(expr); } + for (auto& arg : expr->arguments) + { + arg = maybeOpenExistential(arg); + } + context.originalExpr = expr; context.funcLoc = funcExpr->loc; diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index c9b922415..a17ba0ba2 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -587,6 +587,10 @@ static void _lookUpMembersInSuperTypeImpl( _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, declRef, request, ioResult, inBreadcrumbs); } + if (auto extractExistentialType = as<ExtractExistentialType>(superType)) + { + _lookUpMembersInSuperTypeDeclImpl(astBuilder, name, leafType, superType, leafIsSuperWitness, extractExistentialType->interfaceDeclRef, request, ioResult, inBreadcrumbs); + } } /// Perform lookup for `name` in the context of `type`. diff --git a/tests/compute/dynamic-dispatch-10.slang b/tests/compute/dynamic-dispatch-10.slang new file mode 100644 index 000000000..3e1848186 --- /dev/null +++ b/tests/compute/dynamic-dispatch-10.slang @@ -0,0 +1,47 @@ +//TEST(compute):COMPARE_COMPUTE:-cpu -xslang -allow-dynamic-code +//DISABLE_TEST(compute):COMPARE_COMPUTE:-cuda -xslang -allow-dynamic-code + +// Test dynamic dispatch code gen for specializing a generic with +// an existential value. + +[anyValueSize(16)] +interface IInterface +{ + int Compute(int inVal); +}; + +int GenericCompute0(IInterface obj, int inVal) +{ + return GenericCompute1(obj, obj, inVal); +} + +int GenericCompute1<T:IInterface>(T obj, IInterface obj2, int inVal) +{ + return obj.Compute(inVal) + obj2.Compute(inVal); +} + + +struct Impl : IInterface +{ + int base; + int Compute(int inVal) { return base + inVal * inVal; } +}; + +int test(int inVal) +{ + Impl obj; + obj.base = 1; + return GenericCompute0(obj, inVal); +} + +//TEST_INPUT:ubuffer(data=[0 1 2 3], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer : register(u0); + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + uint tid = dispatchThreadID.x; + int inVal = outputBuffer[tid]; + int outVal = test(inVal); + outputBuffer[tid] = outVal; +} diff --git a/tests/compute/dynamic-dispatch-10.slang.expected.txt b/tests/compute/dynamic-dispatch-10.slang.expected.txt new file mode 100644 index 000000000..70a793021 --- /dev/null +++ b/tests/compute/dynamic-dispatch-10.slang.expected.txt @@ -0,0 +1,4 @@ +2 +4 +A +14 |
