diff options
| -rw-r--r-- | source/slang/slang-ast-base.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl-ref.cpp | 8 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-type-equality-constraint.slang | 35 |
4 files changed, 54 insertions, 6 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 828bbdb5c..fe3554f84 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -673,12 +673,7 @@ class DeclRefBase : public Val SourceLoc getNameLoc() const; SourceLoc getLoc() const; DeclRefBase* getParent(); - String toString() const - { - StringBuilder sb; - const_cast<DeclRefBase*>(this)->toText(sb); - return sb.produceString(); - } + String toString() const; DeclRefBase* getBase(); void toText(StringBuilder& out); }; diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index f7304f308..be71fc334 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -936,6 +936,11 @@ top: { return bIsSubtypeOfCWitness; } + else if (auto declAIsSubtypeOfBWitness = as<DeclaredSubtypeWitness>(aIsSubtypeOfBWitness)) + { + if (declAIsSubtypeOfBWitness->isEquality()) + return bIsSubtypeOfCWitness; + } // Similarly, if `b == c`, then the `a <: b` witness is a witness for `a <: c` // @@ -943,6 +948,11 @@ top: { return aIsSubtypeOfBWitness; } + else if (auto declBIsSubtypeOfCWitness = as<DeclaredSubtypeWitness>(bIsSubtypeOfCWitness)) + { + if (declBIsSubtypeOfCWitness->isEquality()) + return declBIsSubtypeOfCWitness; + } // HACK: There is downstream code generation logic that assumes that // a `TransitiveSubtypeWitness` will never have a transitive witness diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp index 7eec32a3a..c054ef994 100644 --- a/source/slang/slang-ast-decl-ref.cpp +++ b/source/slang/slang-ast-decl-ref.cpp @@ -426,6 +426,14 @@ SourceLoc DeclRefBase::getLoc() const return getDecl()->loc; } +// Keep this function here for better debuggin purpose +String DeclRefBase::toString() const +{ + StringBuilder sb; + const_cast<DeclRefBase*>(this)->toText(sb); + return sb.produceString(); +} + DeclRefBase* DeclRefBase::getParent() { auto astBuilder = getCurrentASTBuilder(); diff --git a/tests/autodiff/self-differential-type-equality-constraint.slang b/tests/autodiff/self-differential-type-equality-constraint.slang new file mode 100644 index 000000000..0baaa07d6 --- /dev/null +++ b/tests/autodiff/self-differential-type-equality-constraint.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type + +interface IV : IDifferentiablePtrType{ + int get(); +} + +struct V : IV +{ + typealias Differential = This; + + int get() + { + return 12; + } +} + +int g<T:IV>(DifferentialPtrPair<T> obj) + where T.Differential == T +{ + return obj.d.get(); +} + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<int> outputBuffer; +[shader("compute")] +void computeMain() +{ + V v = {}; + DifferentialPtrPair<V> p = {v, v}; + + // BUFFER: 12 + outputBuffer[0] = g(p); +} |
