diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-05-02 19:01:21 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-05-02 16:01:21 -0700 |
| commit | c763750a7305fbf12c1f5c177260294a32fe286d (patch) | |
| tree | ec544212bd5bf0ac72cbe3c01d2ed85fc3f05cd9 | |
| parent | e5d49cf21db7a398afe6cfdb76f6b4a028e9eecb (diff) | |
Handle case where types can be used as their own `Differential` type. (#4057)
* Avoid synthesis for when types can be used as their own differenial
+ Add test
* Add missing files..
* Fix issue with method synthesis for self-differential types
+ Add a generic test
* Fix
* Fix issue with out-of-date type resolution cache.
Witness tables created during the conformance checking phase not being taken into account during the decl type resolution phase because the epoch is not updated after conformance checking.
This leads to certain complex associated-type lookup chains (such as the one in tests/compute/assoctype-nested-lookup) not resolving properly and causing errors.
* Delete self-differential-type-synthesis-extension.slang
* Quick fix to repopulate stdlib cache for deferred stdlib loading
* Update slang-check-decl.cpp
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 120 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-options.cpp | 4 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-generic-type-synthesis.slang | 36 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-type-synthesis.slang | 36 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-type-synthesis.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/compute/assoctype-nested-lookup.slang | 44 | ||||
| -rw-r--r-- | tests/compute/assoctype-nested-lookup.slang.expected.txt | 2 |
10 files changed, 274 insertions, 27 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 921bd38e9..0d089874e 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2144,7 +2144,40 @@ namespace Slang SLANG_RELEASE_ASSERT(aggTypeDecl); synth.pushContainerScope(aggTypeDecl); } - else + + // If we did not find an existing empty struct, we may need to synthesize one. + // But first, we check if the parent type can be used as its own differential type. + // + if (!aggTypeDecl + && as<AggTypeDecl>(context->parentDecl) + && canStructBeUsedAsSelfDifferentialType(as<AggTypeDecl>(context->parentDecl))) + { + // If the parent type can be used as its own differential type, we will create a typealias + // to itself as the differential type. + // + auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); + assocTypeDef->nameAndLoc.name = getName("Differential"); + assocTypeDef->type.type = context->conformingType; + assocTypeDef->parentDecl = context->parentDecl; + assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); + context->parentDecl->members.add(assocTypeDef); + + markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType); + + if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable)) + { + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType)); + + // Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. + m_astBuilder->incrementEpoch(); + return true; + } + + // Something went wrong. + return false; + } + + if (!aggTypeDecl) { aggTypeDecl = m_astBuilder->create<StructDecl>(); aggTypeDecl->parentDecl = context->parentDecl; @@ -5741,6 +5774,15 @@ namespace Slang { checkConformance(type, inheritanceDecl, decl); } + + // Successful conformance checking may have created new witness tables. + // Increment epoch to invalidate the cache, so subsequent canonical types are + // re-calculated. + // + // TODO: Is it really necessary to invalidate globally? Maybe there's a way to invalidate only the + // types that are affected by these interface decls. + // + astBuilder->incrementEpoch(); } } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a399ea389..ff74f9a62 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -598,37 +598,60 @@ namespace Slang { case BuiltinRequirementKind::DifferentialType: { - auto structDecl = m_astBuilder->create<StructDecl>(); - auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); - conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); - conformanceDecl->parentDecl = structDecl; - structDecl->members.add(conformanceDecl); - structDecl->parentDecl = parent; - - synthesizedDecl = structDecl; - auto typeDef = m_astBuilder->create<TypeAliasDecl>(); - typeDef->nameAndLoc.name = getName("Differential"); - typeDef->parentDecl = structDecl; - - auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); - - typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); - structDecl->members.add(typeDef); + if (!canStructBeUsedAsSelfDifferentialType(parent)) + { + // Need to create a new struct type for the differential. + // + auto structDecl = m_astBuilder->create<StructDecl>(); + auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); + conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); + conformanceDecl->parentDecl = structDecl; + structDecl->members.add(conformanceDecl); + structDecl->parentDecl = parent; + + synthesizedDecl = structDecl; + auto typeDef = m_astBuilder->create<TypeAliasDecl>(); + typeDef->nameAndLoc.name = getName("Differential"); + typeDef->parentDecl = structDecl; + + auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); + + typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); + structDecl->members.add(typeDef); + + synthesizedDecl->parentDecl = parent; + synthesizedDecl->nameAndLoc.name = item.declRef.getName(); + synthesizedDecl->loc = parent->loc; + parent->members.add(synthesizedDecl); + parent->invalidateMemberDictionary(); + + // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it + // from user-provided definitions, and proceed to fill in its definition. + auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); + addModifier(synthesizedDecl, toBeSynthesized); + } + else + { + // There's no need for a new struct decl. + // We can simply add a typealias to the existing concrete type. + // + auto typeDef = m_astBuilder->create<TypeAliasDecl>(); + typeDef->nameAndLoc.name = item.declRef.getName(); + typeDef->parentDecl = parent; + typeDef->type.type = subType; + + synthesizedDecl = parent; + + parent->members.add(typeDef); + parent->invalidateMemberDictionary(); + + markSelfDifferentialMembersOfType(parent, subType); + } } break; default: return nullptr; } - synthesizedDecl->parentDecl = parent; - synthesizedDecl->nameAndLoc.name = item.declRef.getName(); - synthesizedDecl->loc = parent->loc; - parent->members.add(synthesizedDecl); - parent->invalidateMemberDictionary(); - - // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it - // from user-provided definitions, and proceed to fill in its definition. - auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); - addModifier(synthesizedDecl, toBeSynthesized); auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl); return ConstructDeclRefExpr( @@ -1145,6 +1168,51 @@ namespace Slang return nullptr; } + bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl) + { + // A struct can be used as its own differential type if all its members are differentiable + // and their differential types are the same as the original types. + // + bool canBeUsed = true; + for (auto member : aggTypeDecl->members) + { + if (auto varDecl = as<VarDecl>(member)) + { + // Try to get the differential type of the member. + Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType()); + if (!diffType || !diffType->equals(varDecl->getType())) + { + canBeUsed = false; + break; + } + } + } + return canBeUsed; + } + + void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type) + { + // TODO: Handle extensions. + // Add derivative member attributes to all the fields pointing to themselves. + for (auto member : parent->getMembersOfType<VarDeclBase>()) + { + auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); + auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>(); + fieldLookupExpr->type.type = member->getType(); + + auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = type; + auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + + fieldLookupExpr->declRef = makeDeclRef(member); + + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } + } + Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc) { auto result = tryGetDifferentialType(builder, type); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index d90c3c4b0..fc87c680b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1332,6 +1332,9 @@ namespace Slang Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); Type* tryGetDifferentialType(ASTBuilder* builder, Type* type); + // Helper function to check if a struct can be used as its own differential type. + bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl); + void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type); public: diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 857f4272c..805ea0fff 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -1707,6 +1707,10 @@ SlangResult OptionsParser::_parse( ScopedAllocation contents; SLANG_RETURN_ON_FAIL(File::readAllBytes(fileName.value, contents)); SLANG_RETURN_ON_FAIL(m_session->loadStdLib(contents.getData(), contents.getSizeInBytes())); + + // Ensure that the linkage's AST builder is up-to-date. + linkage->getASTBuilder()->m_cachedNodes = asInternal(m_session)->getGlobalASTBuilder()->m_cachedNodes; + break; } case OptionKind::CompileStdLib: m_compileStdLib = true; break; diff --git a/tests/autodiff/self-differential-generic-type-synthesis.slang b/tests/autodiff/self-differential-generic-type-synthesis.slang new file mode 100644 index 000000000..8d225dec2 --- /dev/null +++ b/tests/autodiff/self-differential-generic-type-synthesis.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type) +// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors. +// + + +struct Ray<let N: int> : IDifferentiable { + float a; + vector<float, N> dir, o; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Ray<4> ray = Ray<4>(); + Ray<4>.Differential ray2; + + ray.a = 1.f; + ray.o = float4(3.f, 4.f, 2.5f, 1.f); + + ray2 = ray; + + float t = 0.f; + float.Differential dt = 0.f; + + t = dt; + + outputBuffer[0] = t; + outputBuffer[1] = ray2.o.y; + outputBuffer[2] = Ray<4>.dadd(ray2, ray2).o.w; +} diff --git a/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt b/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt new file mode 100644 index 000000000..e3160fd7f --- /dev/null +++ b/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.000000 +4.000000 +2.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/self-differential-type-synthesis.slang b/tests/autodiff/self-differential-type-synthesis.slang new file mode 100644 index 000000000..7f95891c6 --- /dev/null +++ b/tests/autodiff/self-differential-type-synthesis.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type) +// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors. +// 1 + +struct Ray : IDifferentiable { + float a; + float3 dir, o; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Ray ray = Ray(); + Ray.Differential ray2; + + ray.a = 1.f; + ray.o = float3(3.f, 4.f, 2.5f); + + ray2 = ray; + + float t = 0.f; + float.Differential dt = 0.f; + + t = dt; + + outputBuffer[0] = t; + outputBuffer[1] = ray2.o.y; + outputBuffer[2] = Ray.dadd(ray2, ray2).a; +} diff --git a/tests/autodiff/self-differential-type-synthesis.slang.expected.txt b/tests/autodiff/self-differential-type-synthesis.slang.expected.txt new file mode 100644 index 000000000..e3160fd7f --- /dev/null +++ b/tests/autodiff/self-differential-type-synthesis.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.000000 +4.000000 +2.000000 +0.000000 +0.000000 diff --git a/tests/compute/assoctype-nested-lookup.slang b/tests/compute/assoctype-nested-lookup.slang new file mode 100644 index 000000000..518e88e25 --- /dev/null +++ b/tests/compute/assoctype-nested-lookup.slang @@ -0,0 +1,44 @@ + + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IFoo +{ + associatedtype Bar : IFoo; +}; + + +struct FooPair<T : IFoo> : IFoo +{ + T a; + T.Bar b; + + typealias Bar = FooPair<T.Bar>; +}; + + +struct ConcreteFoo : IFoo +{ + typealias Bar = ConcreteFoo; + + float x; +}; + +void test(FooPair<ConcreteFoo>.Bar pair) +{ + pair.a.x = 1.0; + pair.b.x = 2.0; + + outputBuffer[0] = pair.a.x + pair.b.x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + FooPair<ConcreteFoo>.Bar pair; + test(pair); +}
\ No newline at end of file diff --git a/tests/compute/assoctype-nested-lookup.slang.expected.txt b/tests/compute/assoctype-nested-lookup.slang.expected.txt new file mode 100644 index 000000000..a6122d7ce --- /dev/null +++ b/tests/compute/assoctype-nested-lookup.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +3.000000 |
