diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-decl.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 12 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 21 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 104 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 157 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 47 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-check.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 229 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 27 | ||||
| -rw-r--r-- | source/slang/slang-lookup.cpp | 45 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 79 |
18 files changed, 414 insertions, 413 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 1711102da..769a1091d 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2737,6 +2737,12 @@ __attributeTarget(InterfaceDecl) attribute_syntax [Specialize] : SpecializeAttribute; __attributeTarget(DeclBase) +attribute_syntax [Differentiable] : DifferentiableAttribute; + +__attributeTarget(DeclBase) +attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; + +__attributeTarget(DeclBase) attribute_syntax [builtin] : BuiltinAttribute; __attributeTarget(DeclBase) diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 07cfe6a0c..b1b20dc93 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -337,7 +337,6 @@ class RefAccessorDecl : public AccessorDecl { SLANG_AST_CLASS(RefAccessorDecl) }; - class FuncDecl : public FunctionDeclBase { SLANG_AST_CLASS(FuncDecl) diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index e0a55cc29..baa6de73a 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -38,18 +38,6 @@ class VarExpr : public DeclRefExpr SLANG_AST_CLASS(VarExpr) }; -class DifferentiableDeclRefExpr : public Expr -{ - SLANG_AST_CLASS(DifferentiableDeclRefExpr) - - // Inner decl ref expr that references a differentiable expression. - Expr* inner = nullptr; - - // Information on getters and setters if available. - Expr* setterExpr = nullptr; - Expr* getterExpr = nullptr; -}; - // An expression that references an overloaded set of declarations // having the same name. class OverloadedExpr : public Expr diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 8230f481e..b019953cb 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -32,6 +32,14 @@ class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoher class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)}; class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)}; +// An `extern` variable in an extension is used to introduce additional attributes on an existing +// field. +class ExtensionExternVarModifier : public Modifier +{ + SLANG_AST_CLASS(ExtensionExternVarModifier) + DeclRef<Decl> originalDecl; +}; + // An 'ActualGlobal' is a global that is output as a normal global in CPU code. // Globals in HLSL/Slang are constant state passed into kernel execution class ActualGlobalModifier : public Modifier { SLANG_AST_CLASS(ActualGlobalModifier)}; @@ -951,6 +959,12 @@ class SpecializeAttribute : public Attribute SLANG_AST_CLASS(SpecializeAttribute) }; + /// An attribute that marks a type, function or variable as differentiable. +class DifferentiableAttribute : public Attribute +{ + SLANG_AST_CLASS(DifferentiableAttribute) +}; + class DllImportAttribute : public Attribute { SLANG_AST_CLASS(DllImportAttribute) @@ -965,6 +979,13 @@ class DllExportAttribute : public Attribute SLANG_AST_CLASS(DllExportAttribute) }; +class DerivativeMemberAttribute : public Attribute +{ + SLANG_AST_CLASS(DerivativeMemberAttribute) + + DeclRefExpr* memberDeclRef; +}; + /// An attribute that marks an interface type as a COM interface declaration. class ComInterfaceAttribute : public Attribute { diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index d6f9a305b..39ca71267 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1115,7 +1115,6 @@ namespace Slang Function = 0x2, Value = 0x4, Attribute = 0x8, - Default = type | Function | Value, }; diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 2d6e20622..356105e4f 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -45,7 +45,10 @@ namespace Slang void visitDecl(Decl*) {} void visitDeclGroup(DeclGroup*) {} - + + void checkDerivativeMemberAttribute(VarDeclBase* varDecl, DerivativeMemberAttribute* attr); + void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m); + void checkVarDeclCommon(VarDeclBase* varDecl); void visitVarDecl(VarDecl* varDecl) @@ -78,6 +81,8 @@ namespace Slang void checkCallableDeclCommon(CallableDecl* decl); + void maybeCheckDifferentiableAccessorSignature(FuncDecl* funcDecl); + void visitFuncDecl(FuncDecl* funcDecl); void visitParamDecl(ParamDecl* paramDecl); @@ -636,6 +641,9 @@ namespace Slang bool SemanticsVisitor::isDeclUsableAsStaticMember( Decl* decl) { + if (m_allowStaticReferenceToNonStaticMember) + return true; + if(auto genericDecl = as<GenericDecl>(decl)) decl = genericDecl->inner; @@ -663,6 +671,9 @@ namespace Slang bool SemanticsVisitor::isUsableAsStaticMember( LookupResultItem const& item) { + if (m_allowStaticReferenceToNonStaticMember) + return true; + // There's a bit of a gotcha here, because a lookup result // item might include "breadcrumbs" that indicate more steps // along the lookup path. As a result it isn't always @@ -966,6 +977,87 @@ namespace Slang tryConstantFoldDeclRef(DeclRef<VarDeclBase>(varDecl, nullptr), nullptr); } + void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttribute( + VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr) + { + auto memberType = checkProperType(getLinkage(), varDecl->type, getSink()); + auto diffType = _getDifferential(m_astBuilder, memberType); + if (as<ErrorType>(diffType)) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType); + } + auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl)); + if (!thisType) + { + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics:: + derivativeMemberAttributeCanOnlyBeUsedOnMembers); + } + auto diffThisType = _getDifferential(m_astBuilder, thisType); + if (!thisType) + { + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable); + } + SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1); + auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember()); + if (auto declRefExpr = as<DeclRefExpr>(checkedExpr)) + { + derivativeMemberAttr->memberDeclRef = declRefExpr; + if (!diffType->equals(declRefExpr->type)) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type); + } + if (!varDecl->parentDecl) + { + getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type); + } + if (auto memberExpr = as<StaticMemberExpr>(declRefExpr)) + { + auto baseExprType = memberExpr->baseExpression->type.type; + if (auto typeType = as<TypeType>(baseExprType)) + { + if (diffThisType->equals(typeType->type)) + { + return; + } + } + + } + } + getSink()->diagnose( + derivativeMemberAttr, + Diagnostics:: + derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, + diffThisType); + } + + void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* extensionExternMemberModifier) + { + if (auto parentExtension = as<ExtensionDecl>(varDecl->parentDecl)) + { + if (auto originalVarDecl = extensionExternMemberModifier->originalDecl.as<VarDeclBase>()) + { + auto originalType = GetTypeForDeclRef(originalVarDecl, originalVarDecl.getLoc()); + auto extVarType = varDecl->type; + if (!extVarType.type || !extVarType.type->equals(originalType)) + { + getSink()->diagnose(varDecl, Diagnostics::typeOfExternDeclMismatchesOriginalDefinition, varDecl, originalType); + } + else + { + return; + } + } + else + { + getSink()->diagnose(varDecl, Diagnostics::definitionOfExternDeclMismatchesOriginalDefinition, varDecl); + } + } + } + void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl) { // A variable that didn't have an explicit type written must @@ -1136,6 +1228,16 @@ namespace Slang getSink()->diagnose(varDecl, Diagnostics::valueRequirementMustBeCompileTimeConst); } } + + // Check modifiers that can't be checked earlier during modifier checking stage. + if (auto derivativeMemberAttr = varDecl->findModifier<DerivativeMemberAttribute>()) + { + checkDerivativeMemberAttribute(varDecl, derivativeMemberAttr); + } + if (auto extensionExternAttr = varDecl->findModifier<ExtensionExternVarModifier>()) + { + checkExtensionExternVarAttribute(varDecl, extensionExternAttr); + } } void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 745532c27..29b44e726 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -755,7 +755,7 @@ namespace Slang } else if (diffTypeLookupResult.isOverloaded()) { - SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported"); + getSink()->diagnose(declRefType->declRef, Diagnostics::ambiguousReference, getName("Differential")); } else { @@ -774,7 +774,7 @@ namespace Slang } } - return nullptr; + return m_astBuilder->getErrorType(); } void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) @@ -813,103 +813,6 @@ namespace Slang } } - Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm) - { - // Check that member lookups on differentiable types have appropriate differential - // getters and setters. - if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) - { - - // Check if we have a parent container. If yes, then checkedTerm is - // referencing a member of this parent. - // - auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent()); - - // Check if we have an aggregate (i.e. struct-like) type. - // Ignore interfaces and the case when the term refers to a function - // - if (parentType->declRef.as<AggTypeDeclBase>() && - !parentType->declRef.as<InterfaceDecl>() && - !declRefExpr->declRef.as<CallableDecl>()) - { - // Check if the parent container type is differentiable. - if (auto parentDiffWitness = as<SubtypeWitness>( - tryGetInterfaceConformanceWitness( - parentType, getASTBuilder()->getDifferentiableInterface()))) - { - // If yes, the member in checkedTerm should have a differential getter and setter. - // Otherwise, <ERROR> - // - auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>(); - diffExpr->type = checkedTerm->type; - diffExpr->inner = checkedTerm; - - { - auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text); - auto getterResult = lookUpMember( - getASTBuilder(), - this, - getterName, - parentType, - Slang::LookupMask::Function, - Slang::LookupOptions::None); - - if (!getterResult.isValid()) - { - // Do nothing.. we assume that this field cannot be differentiated. - // Could this be confusing from a user perspective? - } - else if (getterResult.isOverloaded()) - { - // Diagnose ambiguous getter. - SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported"); - } - else - { - auto getterRefExpr = ConstructLookupResultExpr( - getterResult.item, - declRefExpr, - getterResult.item.declRef.getLoc(), - nullptr); - - // Check that the type is what we expect. - // We're going to do this in a very crude way for now. - // Ideally, we want to use the overload resolution and type - // coercion logic in ResolveInvoke() - // - - auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type); - auto diffParentType = _getDifferential(m_astBuilder, parentType); - - auto ptrDiffType = m_astBuilder->getPtrType(diffType); - auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType); - - auto funcType = as<FuncType>(getterRefExpr->type); - - if (!ptrDiffType->equals(funcType->getResultType())) - { - getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, - ptrDiffType, funcType->getResultType()); - } - - if (!inoutContainerDiffType->equals(funcType->getParamType(0))) - { - getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, - inoutContainerDiffType, funcType->getParamType(0)); - } - - diffExpr->getterExpr = getterRefExpr; - } - } - - return diffExpr; - } - } - } - - return checkedTerm; - } - Expr* SemanticsVisitor::CheckTerm(Expr* term) { auto checkedTerm = _CheckTerm(term); @@ -920,11 +823,6 @@ namespace Slang this->m_parentFunc->findModifier<JVPDerivativeModifier>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); - - if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) - { - checkedTerm = maybeMakeDifferentialExpr(checkedTerm); - } } return checkedTerm; @@ -1888,14 +1786,6 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) - { - auto checkedInnerTerm = CheckTerm(expr->inner); - expr->type = checkedInnerTerm->type; - return expr; - } - - Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) { // Check for type modifiers like 'out' and 'inout'. We need to differentiate the @@ -2729,31 +2619,32 @@ namespace Slang // we can return an overloaded result. if (auto overloadedExpr = as<OverloadedExpr>(baseExpr)) { - if (overloadedExpr->base) + // If a member (dynamic or static) lookup result contains both the actual definition + // and the interface definition obtained from inheritance, we want to filter out + // the interface definitions. + LookupResult filteredLookupResult; + for (auto lookupResult : overloadedExpr->lookupResult2) { - // If a member (dynamic or static) lookup result contains both the actual definition - // and the interface definition obtained from inheritance, we want to filter out - // the interface definitions. - LookupResult filteredLookupResult; - for (auto lookupResult : overloadedExpr->lookupResult2) + bool shouldRemove = false; + if (lookupResult.declRef.getParent().as<InterfaceDecl>()) { - bool shouldRemove = false; - if (lookupResult.declRef.getParent().as<InterfaceDecl>()) - shouldRemove = true; - if (!shouldRemove) - { - filteredLookupResult.items.add(lookupResult); - } + shouldRemove = true; + } + if (lookupResult.declRef.getDecl()->hasModifier<ExtensionExternVarModifier>()) + shouldRemove = true; + if (!shouldRemove) + { + filteredLookupResult.items.add(lookupResult); } - if (filteredLookupResult.items.getCount() == 1) - filteredLookupResult.item = filteredLookupResult.items.getFirst(); - baseExpr = createLookupResultExpr( - overloadedExpr->name, - filteredLookupResult, - overloadedExpr->base, - overloadedExpr->loc, - overloadedExpr); } + if (filteredLookupResult.items.getCount() == 1) + filteredLookupResult.item = filteredLookupResult.items.getFirst(); + baseExpr = createLookupResultExpr( + overloadedExpr->name, + filteredLookupResult, + overloadedExpr->base, + overloadedExpr->loc, + overloadedExpr); // TODO: handle other cases of OverloadedExpr that need filtering. } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 5c1c20e3a..0877f2d6e 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -209,32 +209,10 @@ namespace Slang Substitutions* subst = nullptr; }; - struct LookupRequestKey - { - NodeBase* base; - Name* name; - LookupOptions options; - LookupMask mask; - bool operator==(const LookupRequestKey& other) const - { - return base == other.base && name == other.name && options == other.options && mask == other.mask; - } - HashCode getHashCode() const - { - Hasher hasher; - hasher.hashValue(base); - hasher.hashValue(name); - hasher.hashValue(options); - hasher.hashValue(mask); - return hasher.getResult(); - } - }; - struct TypeCheckingCache { Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache; Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache; - Dictionary<LookupRequestKey, LookupResult> lookupCache; }; struct DifferentiableTypeSemanticContext @@ -305,11 +283,6 @@ namespace Slang bool m_isTypeDictionaryRequired = false; }; - /// Give a cache and a name, will remove all entries associated with a name - /// Might be useful/necessary if a new name is introduced - void removeLookupForName(TypeCheckingCache* cache, Name* name); - - /// Shared state for a semantics-checking session. struct SharedSemanticsContext { @@ -525,6 +498,13 @@ namespace Slang return result; } + SemanticsContext allowStaticReferenceToNonStaticMember() + { + SemanticsContext result(*this); + result.m_allowStaticReferenceToNonStaticMember = true; + return result; + } + private: SharedSemanticsContext* m_shared = nullptr; @@ -545,6 +525,10 @@ namespace Slang /// The type of a try clause (if any) enclosing current expr. TryClauseType m_enclosingTryClauseType = TryClauseType::None; + /// Whether an expr referencing to a non-static member in static style (e.g. `Type.member`) + /// is considered valid in the current context. + bool m_allowStaticReferenceToNonStaticMember = false; + ASTBuilder* m_astBuilder = nullptr; }; @@ -819,11 +803,6 @@ namespace Slang // Check and register a type if it is differentiable. void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type); - // Check if a term is referencing a member, and add a decoration to it's - // differential getter function, if one exists. - // - Expr* maybeMakeDifferentialExpr(Expr* checkedTerm); - // Construct the differential for 'type', if it exists. Type* _getDifferential(ASTBuilder* builder, Type* type); @@ -1018,7 +997,7 @@ namespace Slang bool getAttributeTargetSyntaxClasses(SyntaxClass<NodeBase> & cls, uint32_t typeFlags); - bool validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl); + bool validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget); AttributeBase* checkAttribute( UncheckedAttribute* uncheckedAttr, @@ -1924,8 +1903,6 @@ namespace Slang Expr* visitVarExpr(VarExpr *expr); - Expr* visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr *expr); - Expr* visitTypeCastExpr(TypeCastExpr * expr); Expr* visitTryExpr(TryExpr* expr); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index a2b411c22..f977721dd 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -228,17 +228,6 @@ namespace Slang SLANG_ASSERT(!parentDecl->isMemberDictionaryValid()); - // TODO(JS): A bit of a work around(!) - // - // To get to this point we must have already have performed a lookup for attributeName, - // and it failed. That lookup used the TypeCheckingCache, and - // so we know there is a cache entry that will be *wrong*, now we have created and - // added the AttributeDecl with the attributeName. - // - // To work around, we remove all cached lookups around the name, such that when a subsequent - // lookup is made, the cache will not return the old (wrong) result. - removeLookupForName(getLinkage()->getTypeCheckingCache(), attributeName); - // Finally, we perform any required semantic checks on // the newly constructed attribute decl. // @@ -301,7 +290,7 @@ namespace Slang return false; } - bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl) + bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget) { if(auto numThreadsAttr = as<NumThreadsAttribute>(attr)) { @@ -504,7 +493,6 @@ namespace Slang } else if (auto userDefAttr = as<UserDefinedAttribute>(attr)) { - // check arguments against attribute parameters defined in attribClassDecl Index paramIndex = 0; auto params = attribClassDecl->getMembersOfType<ParamDecl>(); @@ -659,6 +647,15 @@ namespace Slang return false; } } + else if (auto derivativeMemberAttr = as<DerivativeMemberAttribute>(attr)) + { + auto varDecl = as<VarDeclBase>(attrTarget); + if (!varDecl) + { + getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attr->getKeywordName()); + return false; + } + } else { if(attr->args.getCount() == 0) @@ -784,7 +781,7 @@ namespace Slang } // Now apply type-specific validation to the attribute. - if(!validateAttribute(attr, attrDecl)) + if(!validateAttribute(attr, attrDecl, attrTarget)) { return uncheckedAttr; } @@ -817,7 +814,40 @@ namespace Slang CompletionSuggestions::ScopeKind::HLSLSemantics; } } - + + if (auto externModifier = as<ExternModifier>(m)) + { + if (auto varDecl = as<VarDeclBase>(syntaxNode)) + { + if (auto parentExtension = as<ExtensionDecl>(varDecl->parentDecl)) + { + auto originalMemberLookup = lookUpMember(m_astBuilder, this, varDecl->getName(), parentExtension->targetType); + LookupResult filteredResult; + for (auto item : originalMemberLookup.items) + { + if (item.declRef.getDecl() != varDecl) + AddToLookupResult(filteredResult, item); + } + if (filteredResult.isValid() && !filteredResult.isOverloaded()) + { + auto extensionExternMemberModifier = m_astBuilder->create<ExtensionExternVarModifier>(); + extensionExternMemberModifier->originalDecl = filteredResult.item.declRef; + return extensionExternMemberModifier; + } + else if (filteredResult.isOverloaded()) + { + getSink()->diagnose(varDecl, Diagnostics::ambiguousOriginalDefintionOfExternDecl, varDecl); + } + else + { + getSink()->diagnose(varDecl, Diagnostics::missingOriginalDefintionOfExternDecl, varDecl); + } + } + // The next part of the check is to make sure the type defined here is consistent with the original definition. + // Since we haven't checked the type of this decl yet, we defer that until we have fully checked decl. + // See SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute. + } + } // Default behavior is to leave things as they are, // and assume that modifiers are mostly already checked. // diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp index 8c6cddbfe..bcc74a6d0 100644 --- a/source/slang/slang-check.cpp +++ b/source/slang/slang-check.cpp @@ -210,24 +210,4 @@ namespace Slang throw; } } - - void removeLookupForName(TypeCheckingCache* cache, Name* name) - { - auto& lookupCache = cache->lookupCache; - - List<LookupRequestKey> keys; - - for (const auto& pairs : lookupCache) - { - const auto& key = pairs.Key; - if (key.name == name) - { - keys.add(key); - } - } - for (auto& key : keys) - { - lookupCache.Remove(key); - } - } } diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d7e56309a..6e6a6f5e5 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -329,19 +329,26 @@ DIAGNOSTIC(31123, Error, invalidGUID, "'$0' is not a valid GUID") DIAGNOSTIC(31124, Error, structCannotImplementComInterface, "a struct type cannot implement a [COM] interface") DIAGNOSTIC(31124, Error, interfaceInheritingComMustBeCom, "an interface type that inherits from a [COM] interface must itself be a [COM] interface") +DIAGNOSTIC(31130, Error, derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, "[DerivativeMember] must reference to a member in the associated differential type '$0'.") +DIAGNOSTIC(31131, Error, invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable, "invalid use of [DerivativeMember], parent type is not differentiable.") +DIAGNOSTIC(31132, Error, derivativeMemberAttributeCanOnlyBeUsedOnMembers, "[DerivativeMember] is allowed on members only.") + +DIAGNOSTIC(31140, Error, typeOfExternDeclMismatchesOriginalDefinition, "type of `extern` decl '$0' differs from its original definition. expected '$1'.") +DIAGNOSTIC(31141, Error, definitionOfExternDeclMismatchesOriginalDefinition, "`extern` decl '$0' is not consistent with its original definition.") +DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl '$0' has ambiguous original definitions.") +DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.") // Enums DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' tag value expression") - - // 303xx: interfaces and associated types DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.") DIAGNOSTIC(30302, Error, staticConstRequirementMustBeIntOrBool, "'static const' requirement can only have int or bool type.") DIAGNOSTIC(30303, Error, valueRequirementMustBeCompileTimeConst, "requirement in the form of a simple value must be declared as 'static const'.") +DIAGNOSTIC(30310, Error, typeIsNotDifferentiable, "type '$0' is not differentiable.") // Interop DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef<T>` instead") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 1ea54475e..bab33e79d 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -384,6 +384,8 @@ Result linkAndOptimizeIR( // 3. Fill in higher-order invocations with the generated functions. processDerivativeCalls(irModule); + stripAutoDiffDecorations(irModule); + dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF"); validateIRModuleIfEnabled(codeGenContext, irModule); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 843428c01..b97556ab1 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -115,7 +115,7 @@ struct DifferentiableTypeConformanceContext IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) { - if (auto conformance = lookUpConformanceForType(builder, origType)) + if (auto conformance = lookUpConformanceForType(builder, origType)) { if (auto witnessTable = as<IRWitnessTable>(conformance)) { @@ -144,6 +144,14 @@ struct DifferentiableTypeConformanceContext // IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) { + switch (origType->getOp()) + { + case kIROp_FloatType: + case kIROp_HalfType: + case kIROp_DoubleType: + case kIROp_VectorType: + return origType; + } return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey); } @@ -1083,8 +1091,7 @@ struct JVPTranscriber // in the current transcription context. // InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) - { - + { if (as<IRFunc>(origCall->getCallee())) { auto origCallee = origCall->getCallee(); @@ -1094,12 +1101,28 @@ struct JVPTranscriber // auto primalCallee = origCallee; - // TODO: If inner is not differentiable, treat as non-differentiable call. - // Build the differential callee - IRInst* diffCall = builder->emitJVPDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), - primalCallee); - + IRInst* diffCallee = nullptr; + + if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + // If the user has already provided an differentiated implementation, use that. + diffCallee = derivativeReferenceDecor->getJVPFunc(); + } + else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + { + // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass + // to generate the implementation. + diffCallee = builder->emitJVPDifferentiateInst( + differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + primalCallee); + } + else + { + // The callee is non differentiable, just return primal value with null diff value. + IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); + return InstPair(primalCall, nullptr); + } + List<IRInst*> args; // Go over the parameter list and create pairs for each input (if required) for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) @@ -1109,18 +1132,16 @@ struct JVPTranscriber SLANG_ASSERT(primalArg); auto primalType = primalArg->getDataType(); + auto diffArg = findOrTranscribeDiffInst(builder, origArg); + + if (!diffArg) + diffArg = getDifferentialZeroOfType(builder, primalType); + if (auto pairType = tryGetDiffPairType(builder, primalType)) { - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); - // If a pair type can be formed, this must be non-null. SLANG_RELEASE_ASSERT(diffArg); - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); - args.add(diffPair); } else @@ -1130,17 +1151,19 @@ struct JVPTranscriber } } - auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); + IRType* diffReturnType = nullptr; + diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); SLANG_ASSERT(diffReturnType); auto callInst = builder->emitCallInst( diffReturnType, - diffCall, + diffCallee, args); + + IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst); + IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst); - return InstPair( - pairBuilder->emitPrimalFieldAccess(builder, callInst), - pairBuilder->emitDiffFieldAccess(builder, callInst)); + return InstPair(primalResultValue, diffResultValue); } else if(as<IRSpecialize>(origCall->getCallee()) || as<IRLookupWitnessMethod>(origCall->getCallee())) @@ -1396,89 +1419,45 @@ struct JVPTranscriber return InstPair(diffBlock, diffBlock); } - InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract) + InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) { - IRInst* origBase = origExtract->getBase(); + SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst)); + + IRInst* origBase = originalInst->getOperand(0); auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto diffBase = findOrTranscribeDiffInst(builder, origBase); + auto field = originalInst->getOperand(1); + auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>(); + auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType()); - auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType()); - - IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField()); - IRInst* diffExtract = nullptr; + IRInst* primalOperands[] = { primalBase, field }; + IRInst* primalFieldExtract = builder->emitIntrinsicInst( + primalType, + originalInst->getOp(), + 2, + primalOperands); - if (auto diffExtractType = differentiateType(builder, primalExtractType)) + if (!derivativeRefDecor) { - // Check if we have a getter. - if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>()) - { - - IRInst* getterFunc = getterDecoration->getGetterFunc(); - - // Must be a method with a single parameter. - SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1); - - // Our getter func accepts a _pointer_ to the target type - // So we have to create a variable and store our type into memory - // here. This will eventually get optimized out in later passes. - // - auto diffTempVar = builder->emitVar( - diffBase->getDataType()); - - builder->emitStore(diffTempVar, diffBase); - - List<IRInst*> args; - args.add(diffTempVar); - - // Emit a call to the getter. The getter will return a reference type. - // We need to load from this to go to a non-ptr 'solid' type. - // - auto diffGetterCall = builder->emitCallInst( - as<IRFuncType>(getterFunc->getDataType())->getResultType(), - getterFunc, - args); - - diffExtract = builder->emitLoad(diffGetterCall); - } + return InstPair(primalFieldExtract, nullptr); } - return InstPair(primalExtract, diffExtract); - } - - InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress) - { - IRInst* origBase = origAddress->getBase(); - auto primalBase = findOrTranscribePrimalInst(builder, origBase); - auto diffBase = findOrTranscribeDiffInst(builder, origBase); - - auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType()); + IRInst* diffFieldExtract = nullptr; - IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField()); - IRInst* diffAddress = nullptr; - - if (auto diffAddressType = differentiateType(builder, primalAddressType)) + if (auto diffType = differentiateType(builder, primalType)) { - // If we have a getter associated with this field, we want to use that. - if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>()) + if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { - auto getterFunc = getterDecoration->getGetterFunc(); - - // Add the base differential inst as the argument. - List<IRInst*> args; - args.add(diffBase); - - diffAddress = builder->emitCallInst( - as<IRFuncType>(getterFunc->getDataType())->getResultType(), - getterFunc, - args); + IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() }; + diffFieldExtract = builder->emitIntrinsicInst( + diffType, + originalInst->getOp(), + 2, + diffOperands); } - } - - return InstPair(primalAddress, diffAddress); + return InstPair(primalFieldExtract, diffFieldExtract); } - InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr) { SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr)); @@ -1514,7 +1493,6 @@ struct JVPTranscriber return InstPair(primalGetElementPtr, diffGetElementPtr); } - InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop) { // The loop comes with three blocks.. we just need to transcribe each one @@ -1640,9 +1618,13 @@ struct JVPTranscriber as<IRFuncType>(origFunc->getFullType())); diffFunc->setFullType(diffFuncType); - // TODO(sai): Replace naming scheme - // if (auto jvpName = this->getJVPFuncName(builder, primalFn)) - // builder->addNameHintDecoration(diffFunc, jvpName); + if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) + { + auto originalName = nameHint->getName(); + StringBuilder newNameSb; + newNameSb << "s_jvp_" << originalName; + builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + } // Transcribe children from origFunc into diffFunc builder->setInsertInto(diffFunc); @@ -1719,9 +1701,18 @@ struct JVPTranscriber { mapPrimalInst(origInst, pair.primal); mapDifferentialInst(origInst, pair.differential); + if (pair.differential) + { + // Generate name hint for the inst. + if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>()) + { + StringBuilder sb; + sb << "s_diff_" << primalNameHint->getName(); + builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); + } + } return pair.differential; } - instsInProgress.Remove(origInst); getSink()->diagnose(origInst->sourceLoc, @@ -1789,16 +1780,14 @@ struct JVPTranscriber getSink()->diagnose(origInst->sourceLoc, Diagnostics::unexpected, "should not be attempting to differentiate anything specialized here."); + return InstPair(nullptr, nullptr); case kIROp_lookup_interface_method: return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); case kIROp_FieldExtract: - return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst)); - case kIROp_FieldAddress: - return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst)); - + return transcribeFieldExtract(builder, origInst); case kIROp_getElement: case kIROp_getElementPtr: return transcribeGetElement(builder, origInst); @@ -1942,11 +1931,6 @@ struct JVPDerivativeContext // Temporary fix: Move generated types, if any, to before their use locations. (&pairBuilderStorage)->relocateNewTypes(builder); - // Remove all kIROp_DifferentiableTypeDictionary instructions and - // kIROp_DifferentialGetterDecoration decorations - // - modified |= stripDiffTypeInformation(builder, module->getModuleInst()); - return modified; } @@ -1954,7 +1938,6 @@ struct JVPDerivativeContext { if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>()) return jvpDefinition->getJVPFunc(); - return nullptr; } @@ -2166,7 +2149,7 @@ struct JVPDerivativeContext return modified; } - bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent) + bool stripDiffTypeInformation(IRInst* parent) { bool modified = false; @@ -2175,22 +2158,18 @@ struct JVPDerivativeContext { auto nextChild = child->getNextInst(); - if (child->getOp() == kIROp_DifferentiableTypeDictionary) + switch (child->getOp()) { + case kIROp_DifferentiableTypeDictionary: child->removeAndDeallocate(); child = nextChild; modified = true; continue; } - if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>()) - { - getterDecoration->removeAndDeallocate(); - } - if (child->getFirstChild() != nullptr) { - modified |= stripDiffTypeInformation(builder, child); + modified |= stripDiffTypeInformation(child); } child = nextChild; @@ -2311,8 +2290,30 @@ bool processJVPDerivativeMarkers( eliminateDeadCode(module, options); JVPDerivativeContext context(module, sink); + bool changed = context.processModule(); + changed |= context.stripDiffTypeInformation(module->getModuleInst()); + return changed; +} - return context.processModule(); +void stripAutoDiffDecorations(IRModule* module) +{ + for (auto inst : module->getGlobalInsts()) + { + for (auto decor = inst->getFirstDecoration(); decor; ) + { + auto next = decor->getNextDecoration(); + switch (decor->getOp()) + { + case kIROp_JVPDerivativeReferenceDecoration: + case kIROp_JVPDerivativeMemberReferenceDecoration: + decor->removeAndDeallocate(); + break; + default: + break; + } + decor = next; + } + } } } diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h index 8ae6e949a..8ab4e0e8f 100644 --- a/source/slang/slang-ir-diff-jvp.h +++ b/source/slang/slang-ir-diff-jvp.h @@ -18,4 +18,5 @@ namespace Slang DiagnosticSink* sink, IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions()); + void stripAutoDiffDecorations(IRModule* module); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f91fc9cda..c59286116 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -707,8 +707,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0) /// Used by the auto-diff pass to hold a reference to a - /// differential getter associated with this expression. - INST(DifferentialGetterDecoration, diffGetter, 1, 0) + /// differential member of a type in its associated differential type. + INST(JVPDerivativeMemberReferenceDecoration, derivativeMemberDecoration, 1, 0) /// Marks a class type as a COM interface implementation, which enables /// the witness table to be easily picked up by emit. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 33a2fbfb0..5a9c14038 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -546,6 +546,15 @@ struct IRSequentialIDDecoration : IRDecoration IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); } }; +struct IRJVPDerivativeMarkerDecoration : IRDecoration +{ + enum + { + kOp = kIROp_JVPDerivativeMarkerDecoration + }; + IR_LEAF_ISA(JVPDerivativeMarkerDecoration) +}; + struct IRJVPDerivativeReferenceDecoration : IRDecoration { enum @@ -557,15 +566,15 @@ struct IRJVPDerivativeReferenceDecoration : IRDecoration IRInst* getJVPFunc() { return getOperand(0); } }; -struct IRDifferentialGetterDecoration : IRDecoration +struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration { enum { - kOp = kIROp_DifferentialGetterDecoration + kOp = kIROp_JVPDerivativeMemberReferenceDecoration }; - IR_LEAF_ISA(DifferentialGetterDecoration) + IR_LEAF_ISA(JVPDerivativeMemberReferenceDecoration) - IRInst* getGetterFunc() { return getOperand(0); } + IRInst* getDerivativeMemberStructKey() { return getOperand(0); } }; // An instruction that replaces the function symbol @@ -3192,6 +3201,11 @@ public: addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName)); } + void addForceInlineDecoration(IRInst* value) + { + addDecoration(value, kIROp_ForceInlineDecoration); + } + void addJVPDerivativeMarkerDecoration(IRInst* value) { addDecoration(value, kIROp_JVPDerivativeMarkerDecoration); @@ -3202,11 +3216,6 @@ public: addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn); } - void addDifferentialGetterDecoration(IRInst* value, IRInst* getterFn) - { - addDecoration(value, kIROp_DifferentialGetterDecoration, getterFn); - } - void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) { addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1); diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp index cddf3d7ce..c574be4ea 100644 --- a/source/slang/slang-lookup.cpp +++ b/source/slang/slang-lookup.cpp @@ -89,6 +89,18 @@ void buildMemberDictionary(ContainerDecl* decl) bool DeclPassesLookupMask(Decl* decl, LookupMask mask) { + // Always exclude extern members from lookup result. + if (decl->hasModifier<ExtensionExternVarModifier>()) + { + return false; + } + else if (decl->hasModifier<ExternModifier>()) + { + if (as<ExtensionDecl>(decl->parentDecl)) + { + return false; + } + } // type declarations if(auto aggTypeDecl = as<AggTypeDecl>(decl)) { @@ -108,7 +120,7 @@ bool DeclPassesLookupMask(Decl* decl, LookupMask mask) { return (int(mask) & int(LookupMask::Attribute)) != 0; } - + // default behavior is to assume a value declaration // (no overloading allowed) @@ -942,7 +954,7 @@ static void _lookUpInScopes( // The implicit `this`/`This` for a function-like declaration // depends on modifiers attached to the declaration. // - if (funcDeclRef.getDecl()->hasModifier<HLSLStaticModifier>()) + if (isEffectivelyStatic(funcDeclRef.getDecl())) { // A `static` method only has access to an implicit `This`, // and does not have a `this` expression available. @@ -1002,26 +1014,8 @@ LookupResult lookUp( LookupMask mask) { LookupResult result; - LookupRequestKey key; - TypeCheckingCache* typeCheckingCache = nullptr; - if (semantics) - { - typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache(); - key.base = scope; - key.name = name; - key.options = LookupOptions::None; - key.mask = mask; - if (typeCheckingCache->lookupCache.TryGetValue(key, result)) - { - return result; - } - } LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, scope); _lookUpInScopes(astBuilder, name, request, result); - if (typeCheckingCache) - { - typeCheckingCache->lookupCache[key] = result; - } return result; } @@ -1033,20 +1027,9 @@ LookupResult lookUpMember( LookupMask mask, LookupOptions options) { - TypeCheckingCache* typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache(); - LookupRequestKey key; - key.base = type; - key.name = name; - key.options = options; - key.mask = mask; LookupResult result; - if (typeCheckingCache->lookupCache.TryGetValue(key, result)) - { - return result; - } LookupRequest request = initLookupRequest(semantics, name, mask, options, nullptr); _lookUpMembersInType(astBuilder, name, type, request, result, nullptr); - typeCheckingCache->lookupCache[key] = result; return result; } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index dc6067868..1e58a456e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3038,38 +3038,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> return info; } - LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) - { - LoweredValInfo info = lowerSubExpr(expr->inner); - - IRInst* irBaseVal = nullptr; - switch (info.flavor) - { - case LoweredValInfo::Flavor::Simple: - irBaseVal = getSimpleVal(context, info); - break; - - case LoweredValInfo::Flavor::Ptr: - irBaseVal = info.val; - break; - - default: - SLANG_UNEXPECTED("Unhandled lowered value cases"); - } - - // If the differentiable expr has an associated getter or setter, lower it - // and put it in a decoration. - // - if (expr->getterExpr != nullptr) - { - auto irGetter = lowerSubExpr(expr->getterExpr); - SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple); - getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val); - } - - return info; - } - // Emit IR to denote the forward-mode derivative // of the inner func-expr. This will be resolved // to a concrete function during the derivative @@ -6319,7 +6287,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // A variable declared inside of an aggregate type declaration is a member. return true; } - + if (auto extDecl = as<ExtensionDecl>(parent)) + { + if (auto declRefType = as<DeclRefType>(extDecl->targetType.type)) + { + return true; + } + } return false; } @@ -7108,6 +7082,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount()); } + void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember) + { + auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val; + SLANG_RELEASE_ASSERT(as<IRStructKey>(key)); + auto builder = getBuilder(); + builder->addDecoration(inst, kIROp_JVPDerivativeMemberReferenceDecoration, key); + } + LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl) { // Each field declaration in the AST translates into @@ -7120,12 +7102,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // will use the same space of keys. auto builder = getBuilder(); - auto irFieldKey = builder->createStructKey(); - addNameHint(context, irFieldKey, fieldDecl); + IRInst* irFieldKey = nullptr; + if (auto extVarModifier = fieldDecl->findModifier<ExtensionExternVarModifier>()) + { + irFieldKey = ensureDecl(context, extVarModifier->originalDecl.getDecl()).val; + SLANG_RELEASE_ASSERT(as<IRStructKey>(irFieldKey)); + } - addVarDecorations(context, irFieldKey, fieldDecl); + if (!irFieldKey) + { + irFieldKey = builder->createStructKey(); - addLinkageDecoration(context, irFieldKey, fieldDecl); + addNameHint(context, irFieldKey, fieldDecl); + addVarDecorations(context, irFieldKey, fieldDecl); + addLinkageDecoration(context, irFieldKey, fieldDecl); + } if (auto semanticModifier = fieldDecl->findModifier<HLSLSimpleSemantic>()) { @@ -7140,6 +7131,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { lowerRayPayloadAccessModifier(irFieldKey, writeModifier, kIROp_StageWriteAccessDecoration); } + if (auto derivativeMemberModifier = fieldDecl->findModifier<DerivativeMemberAttribute>()) + { + lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier); + } // We allow a field to be marked as a target intrinsic, // so that we can override its mangled name in the @@ -7815,6 +7810,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addJVPDerivativeMarkerDecoration(irFunc); } + // Always force inline diff setter accessor to prevent downstream compiler from complaining + // fields are not fully initialized for the first `inout` parameter. + if (as<SetterDecl>(decl)) + { + if (!decl->findModifier<ForceInlineAttribute>()) + { + getBuilder()->addForceInlineDecoration(irFunc); + } + } + FuncDeclBaseTypeInfo info; _lowerFuncDeclBaseTypeInfo( subContext, |
