summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-check-decl.cpp44
-rw-r--r--source/slang/slang-check-expr.cpp120
-rw-r--r--source/slang/slang-check-impl.h3
-rw-r--r--source/slang/slang-options.cpp4
-rw-r--r--tests/autodiff/self-differential-generic-type-synthesis.slang36
-rw-r--r--tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt6
-rw-r--r--tests/autodiff/self-differential-type-synthesis.slang36
-rw-r--r--tests/autodiff/self-differential-type-synthesis.slang.expected.txt6
-rw-r--r--tests/compute/assoctype-nested-lookup.slang44
-rw-r--r--tests/compute/assoctype-nested-lookup.slang.expected.txt2
10 files changed, 274 insertions, 27 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 921bd38e9..0d089874e 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -2144,7 +2144,40 @@ namespace Slang
SLANG_RELEASE_ASSERT(aggTypeDecl);
synth.pushContainerScope(aggTypeDecl);
}
- else
+
+ // If we did not find an existing empty struct, we may need to synthesize one.
+ // But first, we check if the parent type can be used as its own differential type.
+ //
+ if (!aggTypeDecl
+ && as<AggTypeDecl>(context->parentDecl)
+ && canStructBeUsedAsSelfDifferentialType(as<AggTypeDecl>(context->parentDecl)))
+ {
+ // If the parent type can be used as its own differential type, we will create a typealias
+ // to itself as the differential type.
+ //
+ auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
+ assocTypeDef->nameAndLoc.name = getName("Differential");
+ assocTypeDef->type.type = context->conformingType;
+ assocTypeDef->parentDecl = context->parentDecl;
+ assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked);
+ context->parentDecl->members.add(assocTypeDef);
+
+ markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType);
+
+ if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable))
+ {
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType));
+
+ // Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types.
+ m_astBuilder->incrementEpoch();
+ return true;
+ }
+
+ // Something went wrong.
+ return false;
+ }
+
+ if (!aggTypeDecl)
{
aggTypeDecl = m_astBuilder->create<StructDecl>();
aggTypeDecl->parentDecl = context->parentDecl;
@@ -5741,6 +5774,15 @@ namespace Slang
{
checkConformance(type, inheritanceDecl, decl);
}
+
+ // Successful conformance checking may have created new witness tables.
+ // Increment epoch to invalidate the cache, so subsequent canonical types are
+ // re-calculated.
+ //
+ // TODO: Is it really necessary to invalidate globally? Maybe there's a way to invalidate only the
+ // types that are affected by these interface decls.
+ //
+ astBuilder->incrementEpoch();
}
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index a399ea389..ff74f9a62 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -598,37 +598,60 @@ namespace Slang
{
case BuiltinRequirementKind::DifferentialType:
{
- auto structDecl = m_astBuilder->create<StructDecl>();
- auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
- conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
- conformanceDecl->parentDecl = structDecl;
- structDecl->members.add(conformanceDecl);
- structDecl->parentDecl = parent;
-
- synthesizedDecl = structDecl;
- auto typeDef = m_astBuilder->create<TypeAliasDecl>();
- typeDef->nameAndLoc.name = getName("Differential");
- typeDef->parentDecl = structDecl;
-
- auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
-
- typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
- structDecl->members.add(typeDef);
+ if (!canStructBeUsedAsSelfDifferentialType(parent))
+ {
+ // Need to create a new struct type for the differential.
+ //
+ auto structDecl = m_astBuilder->create<StructDecl>();
+ auto conformanceDecl = m_astBuilder->create<InheritanceDecl>();
+ conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType();
+ conformanceDecl->parentDecl = structDecl;
+ structDecl->members.add(conformanceDecl);
+ structDecl->parentDecl = parent;
+
+ synthesizedDecl = structDecl;
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = getName("Differential");
+ typeDef->parentDecl = structDecl;
+
+ auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl));
+
+ typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef);
+ structDecl->members.add(typeDef);
+
+ synthesizedDecl->parentDecl = parent;
+ synthesizedDecl->nameAndLoc.name = item.declRef.getName();
+ synthesizedDecl->loc = parent->loc;
+ parent->members.add(synthesizedDecl);
+ parent->invalidateMemberDictionary();
+
+ // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
+ // from user-provided definitions, and proceed to fill in its definition.
+ auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
+ addModifier(synthesizedDecl, toBeSynthesized);
+ }
+ else
+ {
+ // There's no need for a new struct decl.
+ // We can simply add a typealias to the existing concrete type.
+ //
+ auto typeDef = m_astBuilder->create<TypeAliasDecl>();
+ typeDef->nameAndLoc.name = item.declRef.getName();
+ typeDef->parentDecl = parent;
+ typeDef->type.type = subType;
+
+ synthesizedDecl = parent;
+
+ parent->members.add(typeDef);
+ parent->invalidateMemberDictionary();
+
+ markSelfDifferentialMembersOfType(parent, subType);
+ }
}
break;
default:
return nullptr;
}
- synthesizedDecl->parentDecl = parent;
- synthesizedDecl->nameAndLoc.name = item.declRef.getName();
- synthesizedDecl->loc = parent->loc;
- parent->members.add(synthesizedDecl);
- parent->invalidateMemberDictionary();
-
- // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it
- // from user-provided definitions, and proceed to fill in its definition.
- auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>();
- addModifier(synthesizedDecl, toBeSynthesized);
auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl);
return ConstructDeclRefExpr(
@@ -1145,6 +1168,51 @@ namespace Slang
return nullptr;
}
+ bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl)
+ {
+ // A struct can be used as its own differential type if all its members are differentiable
+ // and their differential types are the same as the original types.
+ //
+ bool canBeUsed = true;
+ for (auto member : aggTypeDecl->members)
+ {
+ if (auto varDecl = as<VarDecl>(member))
+ {
+ // Try to get the differential type of the member.
+ Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType());
+ if (!diffType || !diffType->equals(varDecl->getType()))
+ {
+ canBeUsed = false;
+ break;
+ }
+ }
+ }
+ return canBeUsed;
+ }
+
+ void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type)
+ {
+ // TODO: Handle extensions.
+ // Add derivative member attributes to all the fields pointing to themselves.
+ for (auto member : parent->getMembersOfType<VarDeclBase>())
+ {
+ auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>();
+ auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>();
+ fieldLookupExpr->type.type = member->getType();
+
+ auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = type;
+ auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type);
+ baseTypeExpr->type.type = baseTypeType;
+ fieldLookupExpr->baseExpression = baseTypeExpr;
+
+ fieldLookupExpr->declRef = makeDeclRef(member);
+
+ derivativeMemberModifier->memberDeclRef = fieldLookupExpr;
+ addModifier(member, derivativeMemberModifier);
+ }
+ }
+
Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc)
{
auto result = tryGetDifferentialType(builder, type);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index d90c3c4b0..fc87c680b 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1332,6 +1332,9 @@ namespace Slang
Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc);
Type* tryGetDifferentialType(ASTBuilder* builder, Type* type);
+ // Helper function to check if a struct can be used as its own differential type.
+ bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl);
+ void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type);
public:
diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp
index 857f4272c..805ea0fff 100644
--- a/source/slang/slang-options.cpp
+++ b/source/slang/slang-options.cpp
@@ -1707,6 +1707,10 @@ SlangResult OptionsParser::_parse(
ScopedAllocation contents;
SLANG_RETURN_ON_FAIL(File::readAllBytes(fileName.value, contents));
SLANG_RETURN_ON_FAIL(m_session->loadStdLib(contents.getData(), contents.getSizeInBytes()));
+
+ // Ensure that the linkage's AST builder is up-to-date.
+ linkage->getASTBuilder()->m_cachedNodes = asInternal(m_session)->getGlobalASTBuilder()->m_cachedNodes;
+
break;
}
case OptionKind::CompileStdLib: m_compileStdLib = true; break;
diff --git a/tests/autodiff/self-differential-generic-type-synthesis.slang b/tests/autodiff/self-differential-generic-type-synthesis.slang
new file mode 100644
index 000000000..8d225dec2
--- /dev/null
+++ b/tests/autodiff/self-differential-generic-type-synthesis.slang
@@ -0,0 +1,36 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type)
+// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors.
+//
+
+
+struct Ray<let N: int> : IDifferentiable {
+ float a;
+ vector<float, N> dir, o;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ Ray<4> ray = Ray<4>();
+ Ray<4>.Differential ray2;
+
+ ray.a = 1.f;
+ ray.o = float4(3.f, 4.f, 2.5f, 1.f);
+
+ ray2 = ray;
+
+ float t = 0.f;
+ float.Differential dt = 0.f;
+
+ t = dt;
+
+ outputBuffer[0] = t;
+ outputBuffer[1] = ray2.o.y;
+ outputBuffer[2] = Ray<4>.dadd(ray2, ray2).o.w;
+}
diff --git a/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt b/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt
new file mode 100644
index 000000000..e3160fd7f
--- /dev/null
+++ b/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+0.000000
+4.000000
+2.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/self-differential-type-synthesis.slang b/tests/autodiff/self-differential-type-synthesis.slang
new file mode 100644
index 000000000..7f95891c6
--- /dev/null
+++ b/tests/autodiff/self-differential-type-synthesis.slang
@@ -0,0 +1,36 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type)
+// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors.
+// 1
+
+struct Ray : IDifferentiable {
+ float a;
+ float3 dir, o;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ Ray ray = Ray();
+ Ray.Differential ray2;
+
+ ray.a = 1.f;
+ ray.o = float3(3.f, 4.f, 2.5f);
+
+ ray2 = ray;
+
+ float t = 0.f;
+ float.Differential dt = 0.f;
+
+ t = dt;
+
+ outputBuffer[0] = t;
+ outputBuffer[1] = ray2.o.y;
+ outputBuffer[2] = Ray.dadd(ray2, ray2).a;
+}
diff --git a/tests/autodiff/self-differential-type-synthesis.slang.expected.txt b/tests/autodiff/self-differential-type-synthesis.slang.expected.txt
new file mode 100644
index 000000000..e3160fd7f
--- /dev/null
+++ b/tests/autodiff/self-differential-type-synthesis.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+0.000000
+4.000000
+2.000000
+0.000000
+0.000000
diff --git a/tests/compute/assoctype-nested-lookup.slang b/tests/compute/assoctype-nested-lookup.slang
new file mode 100644
index 000000000..518e88e25
--- /dev/null
+++ b/tests/compute/assoctype-nested-lookup.slang
@@ -0,0 +1,44 @@
+
+
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+interface IFoo
+{
+ associatedtype Bar : IFoo;
+};
+
+
+struct FooPair<T : IFoo> : IFoo
+{
+ T a;
+ T.Bar b;
+
+ typealias Bar = FooPair<T.Bar>;
+};
+
+
+struct ConcreteFoo : IFoo
+{
+ typealias Bar = ConcreteFoo;
+
+ float x;
+};
+
+void test(FooPair<ConcreteFoo>.Bar pair)
+{
+ pair.a.x = 1.0;
+ pair.b.x = 2.0;
+
+ outputBuffer[0] = pair.a.x + pair.b.x;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ FooPair<ConcreteFoo>.Bar pair;
+ test(pair);
+} \ No newline at end of file
diff --git a/tests/compute/assoctype-nested-lookup.slang.expected.txt b/tests/compute/assoctype-nested-lookup.slang.expected.txt
new file mode 100644
index 000000000..a6122d7ce
--- /dev/null
+++ b/tests/compute/assoctype-nested-lookup.slang.expected.txt
@@ -0,0 +1,2 @@
+type: float
+3.000000