summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ast-base.h7
-rw-r--r--source/slang/slang-ast-builder.cpp10
-rw-r--r--source/slang/slang-ast-decl-ref.cpp8
-rw-r--r--tests/autodiff/self-differential-type-equality-constraint.slang35
4 files changed, 54 insertions, 6 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index 828bbdb5c..fe3554f84 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -673,12 +673,7 @@ class DeclRefBase : public Val
SourceLoc getNameLoc() const;
SourceLoc getLoc() const;
DeclRefBase* getParent();
- String toString() const
- {
- StringBuilder sb;
- const_cast<DeclRefBase*>(this)->toText(sb);
- return sb.produceString();
- }
+ String toString() const;
DeclRefBase* getBase();
void toText(StringBuilder& out);
};
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index f7304f308..be71fc334 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -936,6 +936,11 @@ top:
{
return bIsSubtypeOfCWitness;
}
+ else if (auto declAIsSubtypeOfBWitness = as<DeclaredSubtypeWitness>(aIsSubtypeOfBWitness))
+ {
+ if (declAIsSubtypeOfBWitness->isEquality())
+ return bIsSubtypeOfCWitness;
+ }
// Similarly, if `b == c`, then the `a <: b` witness is a witness for `a <: c`
//
@@ -943,6 +948,11 @@ top:
{
return aIsSubtypeOfBWitness;
}
+ else if (auto declBIsSubtypeOfCWitness = as<DeclaredSubtypeWitness>(bIsSubtypeOfCWitness))
+ {
+ if (declBIsSubtypeOfCWitness->isEquality())
+ return declBIsSubtypeOfCWitness;
+ }
// HACK: There is downstream code generation logic that assumes that
// a `TransitiveSubtypeWitness` will never have a transitive witness
diff --git a/source/slang/slang-ast-decl-ref.cpp b/source/slang/slang-ast-decl-ref.cpp
index 7eec32a3a..c054ef994 100644
--- a/source/slang/slang-ast-decl-ref.cpp
+++ b/source/slang/slang-ast-decl-ref.cpp
@@ -426,6 +426,14 @@ SourceLoc DeclRefBase::getLoc() const
return getDecl()->loc;
}
+// Keep this function here for better debuggin purpose
+String DeclRefBase::toString() const
+{
+ StringBuilder sb;
+ const_cast<DeclRefBase*>(this)->toText(sb);
+ return sb.produceString();
+}
+
DeclRefBase* DeclRefBase::getParent()
{
auto astBuilder = getCurrentASTBuilder();
diff --git a/tests/autodiff/self-differential-type-equality-constraint.slang b/tests/autodiff/self-differential-type-equality-constraint.slang
new file mode 100644
index 000000000..0baaa07d6
--- /dev/null
+++ b/tests/autodiff/self-differential-type-equality-constraint.slang
@@ -0,0 +1,35 @@
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=BUFFER):-cuda -compute -shaderobj -output-using-type
+
+interface IV : IDifferentiablePtrType{
+ int get();
+}
+
+struct V : IV
+{
+ typealias Differential = This;
+
+ int get()
+ {
+ return 12;
+ }
+}
+
+int g<T:IV>(DifferentialPtrPair<T> obj)
+ where T.Differential == T
+{
+ return obj.d.get();
+}
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<int> outputBuffer;
+[shader("compute")]
+void computeMain()
+{
+ V v = {};
+ DifferentialPtrPair<V> p = {v, v};
+
+ // BUFFER: 12
+ outputBuffer[0] = g(p);
+}