diff options
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 44 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 120 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-options.cpp | 4 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-generic-type-synthesis.slang | 36 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-type-synthesis.slang | 36 | ||||
| -rw-r--r-- | tests/autodiff/self-differential-type-synthesis.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/compute/assoctype-nested-lookup.slang | 44 | ||||
| -rw-r--r-- | tests/compute/assoctype-nested-lookup.slang.expected.txt | 2 |
10 files changed, 274 insertions, 27 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 921bd38e9..0d089874e 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2144,7 +2144,40 @@ namespace Slang SLANG_RELEASE_ASSERT(aggTypeDecl); synth.pushContainerScope(aggTypeDecl); } - else + + // If we did not find an existing empty struct, we may need to synthesize one. + // But first, we check if the parent type can be used as its own differential type. + // + if (!aggTypeDecl + && as<AggTypeDecl>(context->parentDecl) + && canStructBeUsedAsSelfDifferentialType(as<AggTypeDecl>(context->parentDecl))) + { + // If the parent type can be used as its own differential type, we will create a typealias + // to itself as the differential type. + // + auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); + assocTypeDef->nameAndLoc.name = getName("Differential"); + assocTypeDef->type.type = context->conformingType; + assocTypeDef->parentDecl = context->parentDecl; + assocTypeDef->setCheckState(DeclCheckState::DefinitionChecked); + context->parentDecl->members.add(assocTypeDef); + + markSelfDifferentialMembersOfType(as<AggTypeDecl>(context->parentDecl), context->conformingType); + + if (doesTypeSatisfyAssociatedTypeConstraintRequirement(context->conformingType, requirementDeclRef, witnessTable)) + { + witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(context->conformingType)); + + // Increase the epoch so that future calls to Type::getCanonicalType will return the up-to-date folded types. + m_astBuilder->incrementEpoch(); + return true; + } + + // Something went wrong. + return false; + } + + if (!aggTypeDecl) { aggTypeDecl = m_astBuilder->create<StructDecl>(); aggTypeDecl->parentDecl = context->parentDecl; @@ -5741,6 +5774,15 @@ namespace Slang { checkConformance(type, inheritanceDecl, decl); } + + // Successful conformance checking may have created new witness tables. + // Increment epoch to invalidate the cache, so subsequent canonical types are + // re-calculated. + // + // TODO: Is it really necessary to invalidate globally? Maybe there's a way to invalidate only the + // types that are affected by these interface decls. + // + astBuilder->incrementEpoch(); } } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index a399ea389..ff74f9a62 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -598,37 +598,60 @@ namespace Slang { case BuiltinRequirementKind::DifferentialType: { - auto structDecl = m_astBuilder->create<StructDecl>(); - auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); - conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); - conformanceDecl->parentDecl = structDecl; - structDecl->members.add(conformanceDecl); - structDecl->parentDecl = parent; - - synthesizedDecl = structDecl; - auto typeDef = m_astBuilder->create<TypeAliasDecl>(); - typeDef->nameAndLoc.name = getName("Differential"); - typeDef->parentDecl = structDecl; - - auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); - - typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); - structDecl->members.add(typeDef); + if (!canStructBeUsedAsSelfDifferentialType(parent)) + { + // Need to create a new struct type for the differential. + // + auto structDecl = m_astBuilder->create<StructDecl>(); + auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); + conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); + conformanceDecl->parentDecl = structDecl; + structDecl->members.add(conformanceDecl); + structDecl->parentDecl = parent; + + synthesizedDecl = structDecl; + auto typeDef = m_astBuilder->create<TypeAliasDecl>(); + typeDef->nameAndLoc.name = getName("Differential"); + typeDef->parentDecl = structDecl; + + auto synthDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(structDecl)); + + typeDef->type.type = DeclRefType::create(m_astBuilder, synthDeclRef); + structDecl->members.add(typeDef); + + synthesizedDecl->parentDecl = parent; + synthesizedDecl->nameAndLoc.name = item.declRef.getName(); + synthesizedDecl->loc = parent->loc; + parent->members.add(synthesizedDecl); + parent->invalidateMemberDictionary(); + + // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it + // from user-provided definitions, and proceed to fill in its definition. + auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); + addModifier(synthesizedDecl, toBeSynthesized); + } + else + { + // There's no need for a new struct decl. + // We can simply add a typealias to the existing concrete type. + // + auto typeDef = m_astBuilder->create<TypeAliasDecl>(); + typeDef->nameAndLoc.name = item.declRef.getName(); + typeDef->parentDecl = parent; + typeDef->type.type = subType; + + synthesizedDecl = parent; + + parent->members.add(typeDef); + parent->invalidateMemberDictionary(); + + markSelfDifferentialMembersOfType(parent, subType); + } } break; default: return nullptr; } - synthesizedDecl->parentDecl = parent; - synthesizedDecl->nameAndLoc.name = item.declRef.getName(); - synthesizedDecl->loc = parent->loc; - parent->members.add(synthesizedDecl); - parent->invalidateMemberDictionary(); - - // Mark the newly synthesized decl as `ToBeSynthesized` so future checking can differentiate it - // from user-provided definitions, and proceed to fill in its definition. - auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); - addModifier(synthesizedDecl, toBeSynthesized); auto synthDeclMemberRef = m_astBuilder->getMemberDeclRef(subType->getDeclRef(), synthesizedDecl); return ConstructDeclRefExpr( @@ -1145,6 +1168,51 @@ namespace Slang return nullptr; } + bool SemanticsVisitor::canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl) + { + // A struct can be used as its own differential type if all its members are differentiable + // and their differential types are the same as the original types. + // + bool canBeUsed = true; + for (auto member : aggTypeDecl->members) + { + if (auto varDecl = as<VarDecl>(member)) + { + // Try to get the differential type of the member. + Type* diffType = tryGetDifferentialType(getASTBuilder(), varDecl->getType()); + if (!diffType || !diffType->equals(varDecl->getType())) + { + canBeUsed = false; + break; + } + } + } + return canBeUsed; + } + + void SemanticsVisitor::markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type) + { + // TODO: Handle extensions. + // Add derivative member attributes to all the fields pointing to themselves. + for (auto member : parent->getMembersOfType<VarDeclBase>()) + { + auto derivativeMemberModifier = m_astBuilder->create<DerivativeMemberAttribute>(); + auto fieldLookupExpr = m_astBuilder->create<StaticMemberExpr>(); + fieldLookupExpr->type.type = member->getType(); + + auto baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = type; + auto baseTypeType = m_astBuilder->getOrCreate<TypeType>(type); + baseTypeExpr->type.type = baseTypeType; + fieldLookupExpr->baseExpression = baseTypeExpr; + + fieldLookupExpr->declRef = makeDeclRef(member); + + derivativeMemberModifier->memberDeclRef = fieldLookupExpr; + addModifier(member, derivativeMemberModifier); + } + } + Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc) { auto result = tryGetDifferentialType(builder, type); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index d90c3c4b0..fc87c680b 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1332,6 +1332,9 @@ namespace Slang Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); Type* tryGetDifferentialType(ASTBuilder* builder, Type* type); + // Helper function to check if a struct can be used as its own differential type. + bool canStructBeUsedAsSelfDifferentialType(AggTypeDecl *aggTypeDecl); + void markSelfDifferentialMembersOfType(AggTypeDecl *parent, Type* type); public: diff --git a/source/slang/slang-options.cpp b/source/slang/slang-options.cpp index 857f4272c..805ea0fff 100644 --- a/source/slang/slang-options.cpp +++ b/source/slang/slang-options.cpp @@ -1707,6 +1707,10 @@ SlangResult OptionsParser::_parse( ScopedAllocation contents; SLANG_RETURN_ON_FAIL(File::readAllBytes(fileName.value, contents)); SLANG_RETURN_ON_FAIL(m_session->loadStdLib(contents.getData(), contents.getSizeInBytes())); + + // Ensure that the linkage's AST builder is up-to-date. + linkage->getASTBuilder()->m_cachedNodes = asInternal(m_session)->getGlobalASTBuilder()->m_cachedNodes; + break; } case OptionKind::CompileStdLib: m_compileStdLib = true; break; diff --git a/tests/autodiff/self-differential-generic-type-synthesis.slang b/tests/autodiff/self-differential-generic-type-synthesis.slang new file mode 100644 index 000000000..8d225dec2 --- /dev/null +++ b/tests/autodiff/self-differential-generic-type-synthesis.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type) +// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors. +// + + +struct Ray<let N: int> : IDifferentiable { + float a; + vector<float, N> dir, o; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Ray<4> ray = Ray<4>(); + Ray<4>.Differential ray2; + + ray.a = 1.f; + ray.o = float4(3.f, 4.f, 2.5f, 1.f); + + ray2 = ray; + + float t = 0.f; + float.Differential dt = 0.f; + + t = dt; + + outputBuffer[0] = t; + outputBuffer[1] = ray2.o.y; + outputBuffer[2] = Ray<4>.dadd(ray2, ray2).o.w; +} diff --git a/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt b/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt new file mode 100644 index 000000000..e3160fd7f --- /dev/null +++ b/tests/autodiff/self-differential-generic-type-synthesis.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.000000 +4.000000 +2.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/self-differential-type-synthesis.slang b/tests/autodiff/self-differential-type-synthesis.slang new file mode 100644 index 000000000..7f95891c6 --- /dev/null +++ b/tests/autodiff/self-differential-type-synthesis.slang @@ -0,0 +1,36 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-cpu -compute -output-using-type -shaderobj + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +// Test that struct types made up of differentiable members who are self-differential (i.e. their Differential type is the same as their type) +// are considered self-differential as well. We should be able to assign T.Differential = T and T = T.Differential without errors. +// 1 + +struct Ray : IDifferentiable { + float a; + float3 dir, o; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Ray ray = Ray(); + Ray.Differential ray2; + + ray.a = 1.f; + ray.o = float3(3.f, 4.f, 2.5f); + + ray2 = ray; + + float t = 0.f; + float.Differential dt = 0.f; + + t = dt; + + outputBuffer[0] = t; + outputBuffer[1] = ray2.o.y; + outputBuffer[2] = Ray.dadd(ray2, ray2).a; +} diff --git a/tests/autodiff/self-differential-type-synthesis.slang.expected.txt b/tests/autodiff/self-differential-type-synthesis.slang.expected.txt new file mode 100644 index 000000000..e3160fd7f --- /dev/null +++ b/tests/autodiff/self-differential-type-synthesis.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +0.000000 +4.000000 +2.000000 +0.000000 +0.000000 diff --git a/tests/compute/assoctype-nested-lookup.slang b/tests/compute/assoctype-nested-lookup.slang new file mode 100644 index 000000000..518e88e25 --- /dev/null +++ b/tests/compute/assoctype-nested-lookup.slang @@ -0,0 +1,44 @@ + + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +interface IFoo +{ + associatedtype Bar : IFoo; +}; + + +struct FooPair<T : IFoo> : IFoo +{ + T a; + T.Bar b; + + typealias Bar = FooPair<T.Bar>; +}; + + +struct ConcreteFoo : IFoo +{ + typealias Bar = ConcreteFoo; + + float x; +}; + +void test(FooPair<ConcreteFoo>.Bar pair) +{ + pair.a.x = 1.0; + pair.b.x = 2.0; + + outputBuffer[0] = pair.a.x + pair.b.x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + FooPair<ConcreteFoo>.Bar pair; + test(pair); +}
\ No newline at end of file diff --git a/tests/compute/assoctype-nested-lookup.slang.expected.txt b/tests/compute/assoctype-nested-lookup.slang.expected.txt new file mode 100644 index 000000000..a6122d7ce --- /dev/null +++ b/tests/compute/assoctype-nested-lookup.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +3.000000 |
