diff options
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 50 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 28 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 37 | ||||
| -rw-r--r-- | tests/autodiff/diff-assoc-type.slang | 30 | ||||
| -rw-r--r-- | tests/autodiff/diff-assoc-type.slang.expected.txt | 2 |
7 files changed, 143 insertions, 27 deletions
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index 362503a64..fdbd56377 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -208,8 +208,19 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe if (auto genericTypeParamDecl = as<GenericTypeParamDecl>(declRef.getDecl())) { if (auto result = maybeSubstituteGenericParam(this, genericTypeParamDecl, subst, ioDiff)) + { + if (auto substDeclRefType = as<DeclRefType>(result)) + { + // After generic substitution, we may be able to further simplify + // by looking up the actual type of an associated type. + if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst( + astBuilder, substDeclRefType->declRef)) + return satisfyingVal; + } return result; + } } + int diff = 0; DeclRef<Decl> substDeclRef = declRef.substituteImpl(astBuilder, subst, &diff); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 580ad8402..c0253fd2c 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1543,18 +1543,48 @@ namespace Slang }; // Make the Differential type itself conform to `IDifferential` interface. - auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>(); - inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType(); - inheritanceIDiffernetiable->parentDecl = aggTypeDecl; - aggTypeDecl->members.add(inheritanceIDiffernetiable); + bool hasDifferentialConformance = false; + for (auto inheritanceDecl : aggTypeDecl->getMembersOfType<InheritanceDecl>()) + { + if (auto declRefType = as<DeclRefType>(inheritanceDecl->base.type)) + { + if (declRefType->declRef == m_astBuilder->getDifferentiableInterface()) + { + hasDifferentialConformance = true; + break; + } + } + } + if (!hasDifferentialConformance) + { + auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>(); + inheritanceIDiffernetiable->base.type = m_astBuilder->getDiffInterfaceType(); + inheritanceIDiffernetiable->parentDecl = aggTypeDecl; + aggTypeDecl->members.add(inheritanceIDiffernetiable); + } // The `Differential` type of a `Differential` type is always itself. - auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); - assocTypeDef->nameAndLoc.name = getName("Differential"); - assocTypeDef->type.type = satisfyingType; - assocTypeDef->parentDecl = aggTypeDecl; - assocTypeDef->setCheckState(DeclCheckState::Checked); - aggTypeDecl->members.add(assocTypeDef); + bool hasDifferentialTypeDef = false; + for (auto member : aggTypeDecl->members) + { + if (auto name = member->getName()) + { + if (name->text == "Differential") + { + hasDifferentialTypeDef = true; + break; + } + } + } + if (!hasDifferentialTypeDef) + { + auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); + assocTypeDef->nameAndLoc.name = getName("Differential"); + assocTypeDef->type.type = satisfyingType; + assocTypeDef->parentDecl = aggTypeDecl; + assocTypeDef->setCheckState(DeclCheckState::Checked); + aggTypeDecl->members.add(assocTypeDef); + } // Go through all members and collect their differential types. // Go through super types. diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 8d8a72dd6..bfad1dbfe 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -487,12 +487,25 @@ namespace Slang switch (builtinAssocTypeAttr->kind) { case BuiltinRequirementKind::DifferentialType: - synthesizedDecl = m_astBuilder->create<StructDecl>(); + { + auto structDecl = m_astBuilder->create<StructDecl>(); + auto conformanceDecl = m_astBuilder->create<InheritanceDecl>(); + conformanceDecl->base.type = m_astBuilder->getDiffInterfaceType(); + conformanceDecl->parentDecl = structDecl; + structDecl->members.add(conformanceDecl); + + synthesizedDecl = structDecl; + auto typeDef = m_astBuilder->create<TypeAliasDecl>(); + typeDef->nameAndLoc.name = getName("Differential"); + auto declRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, DeclRef<Decl>(structDecl, nullptr)); + typeDef->type.type = m_astBuilder->getOrCreateDeclRefType(declRef.decl, declRef.substitutions); + typeDef->parentDecl = structDecl; + structDecl->members.add(typeDef); + } break; default: - break; + return nullptr; } - synthesizedDecl = m_astBuilder->create<StructDecl>(); synthesizedDecl->parentDecl = parent; synthesizedDecl->nameAndLoc.name = item.declRef.getName(); synthesizedDecl->loc = parent->loc; @@ -645,6 +658,15 @@ namespace Slang default: SLANG_UNREACHABLE("all cases handle"); } + if (getShared()->isInLanguageServer()) + { + // Don't make breadcrumb nodes carry any source loc info, + // as they may confuse language server functionalities. + if (bb) + { + bb->loc = SourceLoc(); + } + } } return ConstructDeclRefExpr(item.declRef, bb, loc, originalExpr); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index e9b78696e..224cca9e0 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -384,6 +384,18 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) else { differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness()); + if (auto diffPairType = as<IRDifferentialPairTypeBase>(item->getConcreteType())) + { + // For differential pair types, register the differential type as well. + IRBuilder builder(diffPairType); + builder.setInsertAfter(diffPairType->getWitness()); + auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey); + auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey); + if (diffType && diffWitness) + { + differentiableWitnessDictionary.AddIfNotExists((IRType*)diffType, diffWitness); + } + } } } } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 470f5f983..27aba435f 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1234,23 +1234,32 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } // Hard code implementation of T.Differential.Differential == T.Differential rule. - if (auto builtinReq = substDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>()) + auto foldResult = [&]() -> Val* { - if (builtinReq->kind == BuiltinRequirementKind::DifferentialType) + auto builtinReq = substDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>(); + if (!builtinReq) + return nullptr; + if (builtinReq->kind != BuiltinRequirementKind::DifferentialType) + return nullptr; + // Is the concrete type a Differential associated type? + auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub); + if (!innerDeclRefType) + return nullptr; + auto innerBuiltinReq = innerDeclRefType->declRef.decl->findModifier<BuiltinRequirementModifier>(); + if (!innerBuiltinReq) + return nullptr; + if (innerBuiltinReq->kind != BuiltinRequirementKind::DifferentialType) + return nullptr; + if (!innerDeclRefType->declRef.equals(declRef)) { - // Is the concrete type a Differential associated type? - if (auto innerDeclRefType = as<DeclRefType>(thisSubst->witness->sub)) - { - if (auto innerBuiltinReq = innerDeclRefType->declRef.decl->findModifier<BuiltinRequirementModifier>()) - { - if (innerBuiltinReq->kind == BuiltinRequirementKind::DifferentialType) - { - return innerDeclRefType; - } - } - } + auto result = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(builder, innerDeclRefType->declRef); + if (result) + return result; } - } + return innerDeclRefType; + }(); + if (foldResult) + return foldResult; } } } diff --git a/tests/autodiff/diff-assoc-type.slang b/tests/autodiff/diff-assoc-type.slang new file mode 100644 index 000000000..60a80b32c --- /dev/null +++ b/tests/autodiff/diff-assoc-type.slang @@ -0,0 +1,30 @@ +// Tests automatic synthesis of Differential type and method requirements. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +struct B : IDifferentiable +{ + float x; +} + +float myFunc(DifferentialPair<B.Differential> d) +{ + return d.p.x + d.d.x; +} + +float myFunc2(DifferentialPair<B>.Differential d) +{ + return d.p.x + d.d.x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + B.Differential bd; + bd.x = 1.0; + outputBuffer[0] = myFunc(diffPair(bd, bd)) + myFunc2(diffPair(bd, bd)); +} diff --git a/tests/autodiff/diff-assoc-type.slang.expected.txt b/tests/autodiff/diff-assoc-type.slang.expected.txt new file mode 100644 index 000000000..bc795a8ba --- /dev/null +++ b/tests/autodiff/diff-assoc-type.slang.expected.txt @@ -0,0 +1,2 @@ +type: float +4.0 |
