summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang6
-rw-r--r--source/slang/slang-ast-decl.h1
-rw-r--r--source/slang/slang-ast-expr.h12
-rw-r--r--source/slang/slang-ast-modifier.h21
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-check-decl.cpp104
-rw-r--r--source/slang/slang-check-expr.cpp157
-rw-r--r--source/slang/slang-check-impl.h47
-rw-r--r--source/slang/slang-check-modifier.cpp60
-rw-r--r--source/slang/slang-check.cpp20
-rw-r--r--source/slang/slang-diagnostic-defs.h11
-rw-r--r--source/slang/slang-emit.cpp2
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp229
-rw-r--r--source/slang/slang-ir-diff-jvp.h1
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h27
-rw-r--r--source/slang/slang-lookup.cpp45
-rw-r--r--source/slang/slang-lower-to-ir.cpp79
18 files changed, 414 insertions, 413 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 1711102da..769a1091d 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2737,6 +2737,12 @@ __attributeTarget(InterfaceDecl)
attribute_syntax [Specialize] : SpecializeAttribute;
__attributeTarget(DeclBase)
+attribute_syntax [Differentiable] : DifferentiableAttribute;
+
+__attributeTarget(DeclBase)
+attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
+
+__attributeTarget(DeclBase)
attribute_syntax [builtin] : BuiltinAttribute;
__attributeTarget(DeclBase)
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 07cfe6a0c..b1b20dc93 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -337,7 +337,6 @@ class RefAccessorDecl : public AccessorDecl
{
SLANG_AST_CLASS(RefAccessorDecl)
};
-
class FuncDecl : public FunctionDeclBase
{
SLANG_AST_CLASS(FuncDecl)
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index e0a55cc29..baa6de73a 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -38,18 +38,6 @@ class VarExpr : public DeclRefExpr
SLANG_AST_CLASS(VarExpr)
};
-class DifferentiableDeclRefExpr : public Expr
-{
- SLANG_AST_CLASS(DifferentiableDeclRefExpr)
-
- // Inner decl ref expr that references a differentiable expression.
- Expr* inner = nullptr;
-
- // Information on getters and setters if available.
- Expr* setterExpr = nullptr;
- Expr* getterExpr = nullptr;
-};
-
// An expression that references an overloaded set of declarations
// having the same name.
class OverloadedExpr : public Expr
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 8230f481e..b019953cb 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -32,6 +32,14 @@ class GloballyCoherentModifier : public Modifier { SLANG_AST_CLASS(GloballyCoher
class ExternCppModifier : public Modifier { SLANG_AST_CLASS(ExternCppModifier)};
class JVPDerivativeModifier : public Modifier { SLANG_AST_CLASS(JVPDerivativeModifier)};
+// An `extern` variable in an extension is used to introduce additional attributes on an existing
+// field.
+class ExtensionExternVarModifier : public Modifier
+{
+ SLANG_AST_CLASS(ExtensionExternVarModifier)
+ DeclRef<Decl> originalDecl;
+};
+
// An 'ActualGlobal' is a global that is output as a normal global in CPU code.
// Globals in HLSL/Slang are constant state passed into kernel execution
class ActualGlobalModifier : public Modifier { SLANG_AST_CLASS(ActualGlobalModifier)};
@@ -951,6 +959,12 @@ class SpecializeAttribute : public Attribute
SLANG_AST_CLASS(SpecializeAttribute)
};
+ /// An attribute that marks a type, function or variable as differentiable.
+class DifferentiableAttribute : public Attribute
+{
+ SLANG_AST_CLASS(DifferentiableAttribute)
+};
+
class DllImportAttribute : public Attribute
{
SLANG_AST_CLASS(DllImportAttribute)
@@ -965,6 +979,13 @@ class DllExportAttribute : public Attribute
SLANG_AST_CLASS(DllExportAttribute)
};
+class DerivativeMemberAttribute : public Attribute
+{
+ SLANG_AST_CLASS(DerivativeMemberAttribute)
+
+ DeclRefExpr* memberDeclRef;
+};
+
/// An attribute that marks an interface type as a COM interface declaration.
class ComInterfaceAttribute : public Attribute
{
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index d6f9a305b..39ca71267 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1115,7 +1115,6 @@ namespace Slang
Function = 0x2,
Value = 0x4,
Attribute = 0x8,
-
Default = type | Function | Value,
};
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 2d6e20622..356105e4f 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -45,7 +45,10 @@ namespace Slang
void visitDecl(Decl*) {}
void visitDeclGroup(DeclGroup*) {}
-
+
+ void checkDerivativeMemberAttribute(VarDeclBase* varDecl, DerivativeMemberAttribute* attr);
+ void checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* m);
+
void checkVarDeclCommon(VarDeclBase* varDecl);
void visitVarDecl(VarDecl* varDecl)
@@ -78,6 +81,8 @@ namespace Slang
void checkCallableDeclCommon(CallableDecl* decl);
+ void maybeCheckDifferentiableAccessorSignature(FuncDecl* funcDecl);
+
void visitFuncDecl(FuncDecl* funcDecl);
void visitParamDecl(ParamDecl* paramDecl);
@@ -636,6 +641,9 @@ namespace Slang
bool SemanticsVisitor::isDeclUsableAsStaticMember(
Decl* decl)
{
+ if (m_allowStaticReferenceToNonStaticMember)
+ return true;
+
if(auto genericDecl = as<GenericDecl>(decl))
decl = genericDecl->inner;
@@ -663,6 +671,9 @@ namespace Slang
bool SemanticsVisitor::isUsableAsStaticMember(
LookupResultItem const& item)
{
+ if (m_allowStaticReferenceToNonStaticMember)
+ return true;
+
// There's a bit of a gotcha here, because a lookup result
// item might include "breadcrumbs" that indicate more steps
// along the lookup path. As a result it isn't always
@@ -966,6 +977,87 @@ namespace Slang
tryConstantFoldDeclRef(DeclRef<VarDeclBase>(varDecl, nullptr), nullptr);
}
+ void SemanticsDeclHeaderVisitor::checkDerivativeMemberAttribute(
+ VarDeclBase* varDecl, DerivativeMemberAttribute* derivativeMemberAttr)
+ {
+ auto memberType = checkProperType(getLinkage(), varDecl->type, getSink());
+ auto diffType = _getDifferential(m_astBuilder, memberType);
+ if (as<ErrorType>(diffType))
+ {
+ getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeIsNotDifferentiable, memberType);
+ }
+ auto thisType = calcThisType(makeDeclRef(varDecl->parentDecl));
+ if (!thisType)
+ {
+ getSink()->diagnose(
+ derivativeMemberAttr,
+ Diagnostics::
+ derivativeMemberAttributeCanOnlyBeUsedOnMembers);
+ }
+ auto diffThisType = _getDifferential(m_astBuilder, thisType);
+ if (!thisType)
+ {
+ getSink()->diagnose(
+ derivativeMemberAttr,
+ Diagnostics::invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable);
+ }
+ SLANG_ASSERT(derivativeMemberAttr->args.getCount() == 1);
+ auto checkedExpr = dispatchExpr(derivativeMemberAttr->args[0], allowStaticReferenceToNonStaticMember());
+ if (auto declRefExpr = as<DeclRefExpr>(checkedExpr))
+ {
+ derivativeMemberAttr->memberDeclRef = declRefExpr;
+ if (!diffType->equals(declRefExpr->type))
+ {
+ getSink()->diagnose(derivativeMemberAttr, Diagnostics::typeMismatch, diffType, declRefExpr->type);
+ }
+ if (!varDecl->parentDecl)
+ {
+ getSink()->diagnose(derivativeMemberAttr, Diagnostics::attributeNotApplicable, diffType, declRefExpr->type);
+ }
+ if (auto memberExpr = as<StaticMemberExpr>(declRefExpr))
+ {
+ auto baseExprType = memberExpr->baseExpression->type.type;
+ if (auto typeType = as<TypeType>(baseExprType))
+ {
+ if (diffThisType->equals(typeType->type))
+ {
+ return;
+ }
+ }
+
+ }
+ }
+ getSink()->diagnose(
+ derivativeMemberAttr,
+ Diagnostics::
+ derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType,
+ diffThisType);
+ }
+
+ void SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute(VarDeclBase* varDecl, ExtensionExternVarModifier* extensionExternMemberModifier)
+ {
+ if (auto parentExtension = as<ExtensionDecl>(varDecl->parentDecl))
+ {
+ if (auto originalVarDecl = extensionExternMemberModifier->originalDecl.as<VarDeclBase>())
+ {
+ auto originalType = GetTypeForDeclRef(originalVarDecl, originalVarDecl.getLoc());
+ auto extVarType = varDecl->type;
+ if (!extVarType.type || !extVarType.type->equals(originalType))
+ {
+ getSink()->diagnose(varDecl, Diagnostics::typeOfExternDeclMismatchesOriginalDefinition, varDecl, originalType);
+ }
+ else
+ {
+ return;
+ }
+ }
+ else
+ {
+ getSink()->diagnose(varDecl, Diagnostics::definitionOfExternDeclMismatchesOriginalDefinition, varDecl);
+ }
+ }
+ }
+
void SemanticsDeclHeaderVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
{
// A variable that didn't have an explicit type written must
@@ -1136,6 +1228,16 @@ namespace Slang
getSink()->diagnose(varDecl, Diagnostics::valueRequirementMustBeCompileTimeConst);
}
}
+
+ // Check modifiers that can't be checked earlier during modifier checking stage.
+ if (auto derivativeMemberAttr = varDecl->findModifier<DerivativeMemberAttribute>())
+ {
+ checkDerivativeMemberAttribute(varDecl, derivativeMemberAttr);
+ }
+ if (auto extensionExternAttr = varDecl->findModifier<ExtensionExternVarModifier>())
+ {
+ checkExtensionExternVarAttribute(varDecl, extensionExternAttr);
+ }
}
void SemanticsDeclHeaderVisitor::visitStructDecl(StructDecl* structDecl)
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 745532c27..29b44e726 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -755,7 +755,7 @@ namespace Slang
}
else if (diffTypeLookupResult.isOverloaded())
{
- SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported");
+ getSink()->diagnose(declRefType->declRef, Diagnostics::ambiguousReference, getName("Differential"));
}
else
{
@@ -774,7 +774,7 @@ namespace Slang
}
}
- return nullptr;
+ return m_astBuilder->getErrorType();
}
void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
@@ -813,103 +813,6 @@ namespace Slang
}
}
- Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm)
- {
- // Check that member lookups on differentiable types have appropriate differential
- // getters and setters.
- if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
- {
-
- // Check if we have a parent container. If yes, then checkedTerm is
- // referencing a member of this parent.
- //
- auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent());
-
- // Check if we have an aggregate (i.e. struct-like) type.
- // Ignore interfaces and the case when the term refers to a function
- //
- if (parentType->declRef.as<AggTypeDeclBase>() &&
- !parentType->declRef.as<InterfaceDecl>() &&
- !declRefExpr->declRef.as<CallableDecl>())
- {
- // Check if the parent container type is differentiable.
- if (auto parentDiffWitness = as<SubtypeWitness>(
- tryGetInterfaceConformanceWitness(
- parentType, getASTBuilder()->getDifferentiableInterface())))
- {
- // If yes, the member in checkedTerm should have a differential getter and setter.
- // Otherwise, <ERROR>
- //
- auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>();
- diffExpr->type = checkedTerm->type;
- diffExpr->inner = checkedTerm;
-
- {
- auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text);
- auto getterResult = lookUpMember(
- getASTBuilder(),
- this,
- getterName,
- parentType,
- Slang::LookupMask::Function,
- Slang::LookupOptions::None);
-
- if (!getterResult.isValid())
- {
- // Do nothing.. we assume that this field cannot be differentiated.
- // Could this be confusing from a user perspective?
- }
- else if (getterResult.isOverloaded())
- {
- // Diagnose ambiguous getter.
- SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported");
- }
- else
- {
- auto getterRefExpr = ConstructLookupResultExpr(
- getterResult.item,
- declRefExpr,
- getterResult.item.declRef.getLoc(),
- nullptr);
-
- // Check that the type is what we expect.
- // We're going to do this in a very crude way for now.
- // Ideally, we want to use the overload resolution and type
- // coercion logic in ResolveInvoke()
- //
-
- auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type);
- auto diffParentType = _getDifferential(m_astBuilder, parentType);
-
- auto ptrDiffType = m_astBuilder->getPtrType(diffType);
- auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType);
-
- auto funcType = as<FuncType>(getterRefExpr->type);
-
- if (!ptrDiffType->equals(funcType->getResultType()))
- {
- getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
- ptrDiffType, funcType->getResultType());
- }
-
- if (!inoutContainerDiffType->equals(funcType->getParamType(0)))
- {
- getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
- inoutContainerDiffType, funcType->getParamType(0));
- }
-
- diffExpr->getterExpr = getterRefExpr;
- }
- }
-
- return diffExpr;
- }
- }
- }
-
- return checkedTerm;
- }
-
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
auto checkedTerm = _CheckTerm(term);
@@ -920,11 +823,6 @@ namespace Slang
this->m_parentFunc->findModifier<JVPDerivativeModifier>())
{
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
-
- if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
- {
- checkedTerm = maybeMakeDifferentialExpr(checkedTerm);
- }
}
return checkedTerm;
@@ -1888,14 +1786,6 @@ namespace Slang
return expr;
}
- Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
- {
- auto checkedInnerTerm = CheckTerm(expr->inner);
- expr->type = checkedInnerTerm->type;
- return expr;
- }
-
-
Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
{
// Check for type modifiers like 'out' and 'inout'. We need to differentiate the
@@ -2729,31 +2619,32 @@ namespace Slang
// we can return an overloaded result.
if (auto overloadedExpr = as<OverloadedExpr>(baseExpr))
{
- if (overloadedExpr->base)
+ // If a member (dynamic or static) lookup result contains both the actual definition
+ // and the interface definition obtained from inheritance, we want to filter out
+ // the interface definitions.
+ LookupResult filteredLookupResult;
+ for (auto lookupResult : overloadedExpr->lookupResult2)
{
- // If a member (dynamic or static) lookup result contains both the actual definition
- // and the interface definition obtained from inheritance, we want to filter out
- // the interface definitions.
- LookupResult filteredLookupResult;
- for (auto lookupResult : overloadedExpr->lookupResult2)
+ bool shouldRemove = false;
+ if (lookupResult.declRef.getParent().as<InterfaceDecl>())
{
- bool shouldRemove = false;
- if (lookupResult.declRef.getParent().as<InterfaceDecl>())
- shouldRemove = true;
- if (!shouldRemove)
- {
- filteredLookupResult.items.add(lookupResult);
- }
+ shouldRemove = true;
+ }
+ if (lookupResult.declRef.getDecl()->hasModifier<ExtensionExternVarModifier>())
+ shouldRemove = true;
+ if (!shouldRemove)
+ {
+ filteredLookupResult.items.add(lookupResult);
}
- if (filteredLookupResult.items.getCount() == 1)
- filteredLookupResult.item = filteredLookupResult.items.getFirst();
- baseExpr = createLookupResultExpr(
- overloadedExpr->name,
- filteredLookupResult,
- overloadedExpr->base,
- overloadedExpr->loc,
- overloadedExpr);
}
+ if (filteredLookupResult.items.getCount() == 1)
+ filteredLookupResult.item = filteredLookupResult.items.getFirst();
+ baseExpr = createLookupResultExpr(
+ overloadedExpr->name,
+ filteredLookupResult,
+ overloadedExpr->base,
+ overloadedExpr->loc,
+ overloadedExpr);
// TODO: handle other cases of OverloadedExpr that need filtering.
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 5c1c20e3a..0877f2d6e 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -209,32 +209,10 @@ namespace Slang
Substitutions* subst = nullptr;
};
- struct LookupRequestKey
- {
- NodeBase* base;
- Name* name;
- LookupOptions options;
- LookupMask mask;
- bool operator==(const LookupRequestKey& other) const
- {
- return base == other.base && name == other.name && options == other.options && mask == other.mask;
- }
- HashCode getHashCode() const
- {
- Hasher hasher;
- hasher.hashValue(base);
- hasher.hashValue(name);
- hasher.hashValue(options);
- hasher.hashValue(mask);
- return hasher.getResult();
- }
- };
-
struct TypeCheckingCache
{
Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
- Dictionary<LookupRequestKey, LookupResult> lookupCache;
};
struct DifferentiableTypeSemanticContext
@@ -305,11 +283,6 @@ namespace Slang
bool m_isTypeDictionaryRequired = false;
};
- /// Give a cache and a name, will remove all entries associated with a name
- /// Might be useful/necessary if a new name is introduced
- void removeLookupForName(TypeCheckingCache* cache, Name* name);
-
-
/// Shared state for a semantics-checking session.
struct SharedSemanticsContext
{
@@ -525,6 +498,13 @@ namespace Slang
return result;
}
+ SemanticsContext allowStaticReferenceToNonStaticMember()
+ {
+ SemanticsContext result(*this);
+ result.m_allowStaticReferenceToNonStaticMember = true;
+ return result;
+ }
+
private:
SharedSemanticsContext* m_shared = nullptr;
@@ -545,6 +525,10 @@ namespace Slang
/// The type of a try clause (if any) enclosing current expr.
TryClauseType m_enclosingTryClauseType = TryClauseType::None;
+ /// Whether an expr referencing to a non-static member in static style (e.g. `Type.member`)
+ /// is considered valid in the current context.
+ bool m_allowStaticReferenceToNonStaticMember = false;
+
ASTBuilder* m_astBuilder = nullptr;
};
@@ -819,11 +803,6 @@ namespace Slang
// Check and register a type if it is differentiable.
void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
- // Check if a term is referencing a member, and add a decoration to it's
- // differential getter function, if one exists.
- //
- Expr* maybeMakeDifferentialExpr(Expr* checkedTerm);
-
// Construct the differential for 'type', if it exists.
Type* _getDifferential(ASTBuilder* builder, Type* type);
@@ -1018,7 +997,7 @@ namespace Slang
bool getAttributeTargetSyntaxClasses(SyntaxClass<NodeBase> & cls, uint32_t typeFlags);
- bool validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl);
+ bool validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget);
AttributeBase* checkAttribute(
UncheckedAttribute* uncheckedAttr,
@@ -1924,8 +1903,6 @@ namespace Slang
Expr* visitVarExpr(VarExpr *expr);
- Expr* visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr *expr);
-
Expr* visitTypeCastExpr(TypeCastExpr * expr);
Expr* visitTryExpr(TryExpr* expr);
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index a2b411c22..f977721dd 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -228,17 +228,6 @@ namespace Slang
SLANG_ASSERT(!parentDecl->isMemberDictionaryValid());
- // TODO(JS): A bit of a work around(!)
- //
- // To get to this point we must have already have performed a lookup for attributeName,
- // and it failed. That lookup used the TypeCheckingCache, and
- // so we know there is a cache entry that will be *wrong*, now we have created and
- // added the AttributeDecl with the attributeName.
- //
- // To work around, we remove all cached lookups around the name, such that when a subsequent
- // lookup is made, the cache will not return the old (wrong) result.
- removeLookupForName(getLinkage()->getTypeCheckingCache(), attributeName);
-
// Finally, we perform any required semantic checks on
// the newly constructed attribute decl.
//
@@ -301,7 +290,7 @@ namespace Slang
return false;
}
- bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl)
+ bool SemanticsVisitor::validateAttribute(Attribute* attr, AttributeDecl* attribClassDecl, ModifiableSyntaxNode* attrTarget)
{
if(auto numThreadsAttr = as<NumThreadsAttribute>(attr))
{
@@ -504,7 +493,6 @@ namespace Slang
}
else if (auto userDefAttr = as<UserDefinedAttribute>(attr))
{
-
// check arguments against attribute parameters defined in attribClassDecl
Index paramIndex = 0;
auto params = attribClassDecl->getMembersOfType<ParamDecl>();
@@ -659,6 +647,15 @@ namespace Slang
return false;
}
}
+ else if (auto derivativeMemberAttr = as<DerivativeMemberAttribute>(attr))
+ {
+ auto varDecl = as<VarDeclBase>(attrTarget);
+ if (!varDecl)
+ {
+ getSink()->diagnose(attr, Diagnostics::attributeNotApplicable, attr->getKeywordName());
+ return false;
+ }
+ }
else
{
if(attr->args.getCount() == 0)
@@ -784,7 +781,7 @@ namespace Slang
}
// Now apply type-specific validation to the attribute.
- if(!validateAttribute(attr, attrDecl))
+ if(!validateAttribute(attr, attrDecl, attrTarget))
{
return uncheckedAttr;
}
@@ -817,7 +814,40 @@ namespace Slang
CompletionSuggestions::ScopeKind::HLSLSemantics;
}
}
-
+
+ if (auto externModifier = as<ExternModifier>(m))
+ {
+ if (auto varDecl = as<VarDeclBase>(syntaxNode))
+ {
+ if (auto parentExtension = as<ExtensionDecl>(varDecl->parentDecl))
+ {
+ auto originalMemberLookup = lookUpMember(m_astBuilder, this, varDecl->getName(), parentExtension->targetType);
+ LookupResult filteredResult;
+ for (auto item : originalMemberLookup.items)
+ {
+ if (item.declRef.getDecl() != varDecl)
+ AddToLookupResult(filteredResult, item);
+ }
+ if (filteredResult.isValid() && !filteredResult.isOverloaded())
+ {
+ auto extensionExternMemberModifier = m_astBuilder->create<ExtensionExternVarModifier>();
+ extensionExternMemberModifier->originalDecl = filteredResult.item.declRef;
+ return extensionExternMemberModifier;
+ }
+ else if (filteredResult.isOverloaded())
+ {
+ getSink()->diagnose(varDecl, Diagnostics::ambiguousOriginalDefintionOfExternDecl, varDecl);
+ }
+ else
+ {
+ getSink()->diagnose(varDecl, Diagnostics::missingOriginalDefintionOfExternDecl, varDecl);
+ }
+ }
+ // The next part of the check is to make sure the type defined here is consistent with the original definition.
+ // Since we haven't checked the type of this decl yet, we defer that until we have fully checked decl.
+ // See SemanticsDeclHeaderVisitor::checkExtensionExternVarAttribute.
+ }
+ }
// Default behavior is to leave things as they are,
// and assume that modifiers are mostly already checked.
//
diff --git a/source/slang/slang-check.cpp b/source/slang/slang-check.cpp
index 8c6cddbfe..bcc74a6d0 100644
--- a/source/slang/slang-check.cpp
+++ b/source/slang/slang-check.cpp
@@ -210,24 +210,4 @@ namespace Slang
throw;
}
}
-
- void removeLookupForName(TypeCheckingCache* cache, Name* name)
- {
- auto& lookupCache = cache->lookupCache;
-
- List<LookupRequestKey> keys;
-
- for (const auto& pairs : lookupCache)
- {
- const auto& key = pairs.Key;
- if (key.name == name)
- {
- keys.add(key);
- }
- }
- for (auto& key : keys)
- {
- lookupCache.Remove(key);
- }
- }
}
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index d7e56309a..6e6a6f5e5 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -329,19 +329,26 @@ DIAGNOSTIC(31123, Error, invalidGUID, "'$0' is not a valid GUID")
DIAGNOSTIC(31124, Error, structCannotImplementComInterface, "a struct type cannot implement a [COM] interface")
DIAGNOSTIC(31124, Error, interfaceInheritingComMustBeCom, "an interface type that inherits from a [COM] interface must itself be a [COM] interface")
+DIAGNOSTIC(31130, Error, derivativeMemberAttributeMustNameAMemberInExpectedDifferentialType, "[DerivativeMember] must reference to a member in the associated differential type '$0'.")
+DIAGNOSTIC(31131, Error, invalidUseOfDerivativeMemberAttributeParentTypeIsNotDifferentiable, "invalid use of [DerivativeMember], parent type is not differentiable.")
+DIAGNOSTIC(31132, Error, derivativeMemberAttributeCanOnlyBeUsedOnMembers, "[DerivativeMember] is allowed on members only.")
+
+DIAGNOSTIC(31140, Error, typeOfExternDeclMismatchesOriginalDefinition, "type of `extern` decl '$0' differs from its original definition. expected '$1'.")
+DIAGNOSTIC(31141, Error, definitionOfExternDeclMismatchesOriginalDefinition, "`extern` decl '$0' is not consistent with its original definition.")
+DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl '$0' has ambiguous original definitions.")
+DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.")
// Enums
DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'")
DIAGNOSTIC(32003, Error, unexpectedEnumTagExpr, "unexpected form for 'enum' tag value expression")
-
-
// 303xx: interfaces and associated types
DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.")
DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'type_param' can only be defined global scope.")
DIAGNOSTIC(30302, Error, staticConstRequirementMustBeIntOrBool, "'static const' requirement can only have int or bool type.")
DIAGNOSTIC(30303, Error, valueRequirementMustBeCompileTimeConst, "requirement in the form of a simple value must be declared as 'static const'.")
+DIAGNOSTIC(30310, Error, typeIsNotDifferentiable, "type '$0' is not differentiable.")
// Interop
DIAGNOSTIC(30400, Error, cannotDefinePtrTypeToManagedResource, "pointer to a managed resource is invalid, use `NativeRef<T>` instead")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index 1ea54475e..bab33e79d 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -384,6 +384,8 @@ Result linkAndOptimizeIR(
// 3. Fill in higher-order invocations with the generated functions.
processDerivativeCalls(irModule);
+ stripAutoDiffDecorations(irModule);
+
dumpIRIfEnabled(codeGenContext, irModule, "AFTER-AUTODIFF");
validateIRModuleIfEnabled(codeGenContext, irModule);
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 843428c01..b97556ab1 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -115,7 +115,7 @@ struct DifferentiableTypeConformanceContext
IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
{
- if (auto conformance = lookUpConformanceForType(builder, origType))
+ if (auto conformance = lookUpConformanceForType(builder, origType))
{
if (auto witnessTable = as<IRWitnessTable>(conformance))
{
@@ -144,6 +144,14 @@ struct DifferentiableTypeConformanceContext
//
IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
{
+ switch (origType->getOp())
+ {
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ return origType;
+ }
return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey);
}
@@ -1083,8 +1091,7 @@ struct JVPTranscriber
// in the current transcription context.
//
InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
- {
-
+ {
if (as<IRFunc>(origCall->getCallee()))
{
auto origCallee = origCall->getCallee();
@@ -1094,12 +1101,28 @@ struct JVPTranscriber
//
auto primalCallee = origCallee;
- // TODO: If inner is not differentiable, treat as non-differentiable call.
- // Build the differential callee
- IRInst* diffCall = builder->emitJVPDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
- primalCallee);
-
+ IRInst* diffCallee = nullptr;
+
+ if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ // If the user has already provided an differentiated implementation, use that.
+ diffCallee = derivativeReferenceDecor->getJVPFunc();
+ }
+ else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ {
+ // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
+ // to generate the implementation.
+ diffCallee = builder->emitJVPDifferentiateInst(
+ differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ primalCallee);
+ }
+ else
+ {
+ // The callee is non differentiable, just return primal value with null diff value.
+ IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
+ return InstPair(primalCall, nullptr);
+ }
+
List<IRInst*> args;
// Go over the parameter list and create pairs for each input (if required)
for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
@@ -1109,18 +1132,16 @@ struct JVPTranscriber
SLANG_ASSERT(primalArg);
auto primalType = primalArg->getDataType();
+ auto diffArg = findOrTranscribeDiffInst(builder, origArg);
+
+ if (!diffArg)
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+
if (auto pairType = tryGetDiffPairType(builder, primalType))
{
- auto diffArg = findOrTranscribeDiffInst(builder, origArg);
-
- if (!diffArg)
- diffArg = getDifferentialZeroOfType(builder, primalType);
-
// If a pair type can be formed, this must be non-null.
SLANG_RELEASE_ASSERT(diffArg);
-
auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
-
args.add(diffPair);
}
else
@@ -1130,17 +1151,19 @@ struct JVPTranscriber
}
}
- auto diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
+ IRType* diffReturnType = nullptr;
+ diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
SLANG_ASSERT(diffReturnType);
auto callInst = builder->emitCallInst(
diffReturnType,
- diffCall,
+ diffCallee,
args);
+
+ IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst);
+ IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst);
- return InstPair(
- pairBuilder->emitPrimalFieldAccess(builder, callInst),
- pairBuilder->emitDiffFieldAccess(builder, callInst));
+ return InstPair(primalResultValue, diffResultValue);
}
else if(as<IRSpecialize>(origCall->getCallee()) ||
as<IRLookupWitnessMethod>(origCall->getCallee()))
@@ -1396,89 +1419,45 @@ struct JVPTranscriber
return InstPair(diffBlock, diffBlock);
}
- InstPair transcribeFieldExtract(IRBuilder* builder, IRFieldExtract* origExtract)
+ InstPair transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst)
{
- IRInst* origBase = origExtract->getBase();
+ SLANG_ASSERT(as<IRFieldExtract>(originalInst) || as<IRFieldAddress>(originalInst));
+
+ IRInst* origBase = originalInst->getOperand(0);
auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto diffBase = findOrTranscribeDiffInst(builder, origBase);
+ auto field = originalInst->getOperand(1);
+ auto derivativeRefDecor = field->findDecoration<IRJVPDerivativeMemberReferenceDecoration>();
+ auto primalType = (IRType*)lookupPrimalInst(originalInst->getDataType(), originalInst->getDataType());
- auto primalExtractType = (IRType*)lookupPrimalInst(origExtract->getDataType(), origExtract->getDataType());
-
- IRInst* primalExtract = builder->emitFieldExtract(primalExtractType, primalBase, origExtract->getField());
- IRInst* diffExtract = nullptr;
+ IRInst* primalOperands[] = { primalBase, field };
+ IRInst* primalFieldExtract = builder->emitIntrinsicInst(
+ primalType,
+ originalInst->getOp(),
+ 2,
+ primalOperands);
- if (auto diffExtractType = differentiateType(builder, primalExtractType))
+ if (!derivativeRefDecor)
{
- // Check if we have a getter.
- if (auto getterDecoration = origExtract->findDecoration<IRDifferentialGetterDecoration>())
- {
-
- IRInst* getterFunc = getterDecoration->getGetterFunc();
-
- // Must be a method with a single parameter.
- SLANG_ASSERT(as<IRFuncType>(getterFunc->getDataType())->getParamCount() == 1);
-
- // Our getter func accepts a _pointer_ to the target type
- // So we have to create a variable and store our type into memory
- // here. This will eventually get optimized out in later passes.
- //
- auto diffTempVar = builder->emitVar(
- diffBase->getDataType());
-
- builder->emitStore(diffTempVar, diffBase);
-
- List<IRInst*> args;
- args.add(diffTempVar);
-
- // Emit a call to the getter. The getter will return a reference type.
- // We need to load from this to go to a non-ptr 'solid' type.
- //
- auto diffGetterCall = builder->emitCallInst(
- as<IRFuncType>(getterFunc->getDataType())->getResultType(),
- getterFunc,
- args);
-
- diffExtract = builder->emitLoad(diffGetterCall);
- }
+ return InstPair(primalFieldExtract, nullptr);
}
- return InstPair(primalExtract, diffExtract);
- }
-
- InstPair transcribeFieldAddress(IRBuilder* builder, IRFieldAddress* origAddress)
- {
- IRInst* origBase = origAddress->getBase();
- auto primalBase = findOrTranscribePrimalInst(builder, origBase);
- auto diffBase = findOrTranscribeDiffInst(builder, origBase);
-
- auto primalAddressType = (IRType*)lookupPrimalInst(origAddress->getDataType(), origAddress->getDataType());
+ IRInst* diffFieldExtract = nullptr;
- IRInst* primalAddress = builder->emitFieldAddress(primalAddressType, primalBase, origAddress->getField());
- IRInst* diffAddress = nullptr;
-
- if (auto diffAddressType = differentiateType(builder, primalAddressType))
+ if (auto diffType = differentiateType(builder, primalType))
{
- // If we have a getter associated with this field, we want to use that.
- if (auto getterDecoration = origAddress->findDecoration<IRDifferentialGetterDecoration>())
+ if (auto diffBase = findOrTranscribeDiffInst(builder, origBase))
{
- auto getterFunc = getterDecoration->getGetterFunc();
-
- // Add the base differential inst as the argument.
- List<IRInst*> args;
- args.add(diffBase);
-
- diffAddress = builder->emitCallInst(
- as<IRFuncType>(getterFunc->getDataType())->getResultType(),
- getterFunc,
- args);
+ IRInst* diffOperands[] = { diffBase, derivativeRefDecor->getDerivativeMemberStructKey() };
+ diffFieldExtract = builder->emitIntrinsicInst(
+ diffType,
+ originalInst->getOp(),
+ 2,
+ diffOperands);
}
-
}
-
- return InstPair(primalAddress, diffAddress);
+ return InstPair(primalFieldExtract, diffFieldExtract);
}
-
InstPair transcribeGetElement(IRBuilder* builder, IRInst* origGetElementPtr)
{
SLANG_ASSERT(as<IRGetElement>(origGetElementPtr) || as<IRGetElementPtr>(origGetElementPtr));
@@ -1514,7 +1493,6 @@ struct JVPTranscriber
return InstPair(primalGetElementPtr, diffGetElementPtr);
}
-
InstPair transcribeLoop(IRBuilder* builder, IRLoop* origLoop)
{
// The loop comes with three blocks.. we just need to transcribe each one
@@ -1640,9 +1618,13 @@ struct JVPTranscriber
as<IRFuncType>(origFunc->getFullType()));
diffFunc->setFullType(diffFuncType);
- // TODO(sai): Replace naming scheme
- // if (auto jvpName = this->getJVPFuncName(builder, primalFn))
- // builder->addNameHintDecoration(diffFunc, jvpName);
+ if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_jvp_" << originalName;
+ builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
// Transcribe children from origFunc into diffFunc
builder->setInsertInto(diffFunc);
@@ -1719,9 +1701,18 @@ struct JVPTranscriber
{
mapPrimalInst(origInst, pair.primal);
mapDifferentialInst(origInst, pair.differential);
+ if (pair.differential)
+ {
+ // Generate name hint for the inst.
+ if (auto primalNameHint = primalInst->findDecoration<IRNameHintDecoration>())
+ {
+ StringBuilder sb;
+ sb << "s_diff_" << primalNameHint->getName();
+ builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice());
+ }
+ }
return pair.differential;
}
-
instsInProgress.Remove(origInst);
getSink()->diagnose(origInst->sourceLoc,
@@ -1789,16 +1780,14 @@ struct JVPTranscriber
getSink()->diagnose(origInst->sourceLoc,
Diagnostics::unexpected,
"should not be attempting to differentiate anything specialized here.");
+ return InstPair(nullptr, nullptr);
case kIROp_lookup_interface_method:
return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
case kIROp_FieldExtract:
- return transcribeFieldExtract(builder, as<IRFieldExtract>(origInst));
-
case kIROp_FieldAddress:
- return transcribeFieldAddress(builder, as<IRFieldAddress>(origInst));
-
+ return transcribeFieldExtract(builder, origInst);
case kIROp_getElement:
case kIROp_getElementPtr:
return transcribeGetElement(builder, origInst);
@@ -1942,11 +1931,6 @@ struct JVPDerivativeContext
// Temporary fix: Move generated types, if any, to before their use locations.
(&pairBuilderStorage)->relocateNewTypes(builder);
- // Remove all kIROp_DifferentiableTypeDictionary instructions and
- // kIROp_DifferentialGetterDecoration decorations
- //
- modified |= stripDiffTypeInformation(builder, module->getModuleInst());
-
return modified;
}
@@ -1954,7 +1938,6 @@ struct JVPDerivativeContext
{
if(auto jvpDefinition = primalFunction->findDecoration<IRJVPDerivativeReferenceDecoration>())
return jvpDefinition->getJVPFunc();
-
return nullptr;
}
@@ -2166,7 +2149,7 @@ struct JVPDerivativeContext
return modified;
}
- bool stripDiffTypeInformation(IRBuilder* builder, IRInst* parent)
+ bool stripDiffTypeInformation(IRInst* parent)
{
bool modified = false;
@@ -2175,22 +2158,18 @@ struct JVPDerivativeContext
{
auto nextChild = child->getNextInst();
- if (child->getOp() == kIROp_DifferentiableTypeDictionary)
+ switch (child->getOp())
{
+ case kIROp_DifferentiableTypeDictionary:
child->removeAndDeallocate();
child = nextChild;
modified = true;
continue;
}
- if (auto getterDecoration = child->findDecoration<IRDifferentialGetterDecoration>())
- {
- getterDecoration->removeAndDeallocate();
- }
-
if (child->getFirstChild() != nullptr)
{
- modified |= stripDiffTypeInformation(builder, child);
+ modified |= stripDiffTypeInformation(child);
}
child = nextChild;
@@ -2311,8 +2290,30 @@ bool processJVPDerivativeMarkers(
eliminateDeadCode(module, options);
JVPDerivativeContext context(module, sink);
+ bool changed = context.processModule();
+ changed |= context.stripDiffTypeInformation(module->getModuleInst());
+ return changed;
+}
- return context.processModule();
+void stripAutoDiffDecorations(IRModule* module)
+{
+ for (auto inst : module->getGlobalInsts())
+ {
+ for (auto decor = inst->getFirstDecoration(); decor; )
+ {
+ auto next = decor->getNextDecoration();
+ switch (decor->getOp())
+ {
+ case kIROp_JVPDerivativeReferenceDecoration:
+ case kIROp_JVPDerivativeMemberReferenceDecoration:
+ decor->removeAndDeallocate();
+ break;
+ default:
+ break;
+ }
+ decor = next;
+ }
+ }
}
}
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
index 8ae6e949a..8ab4e0e8f 100644
--- a/source/slang/slang-ir-diff-jvp.h
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -18,4 +18,5 @@ namespace Slang
DiagnosticSink* sink,
IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions());
+ void stripAutoDiffDecorations(IRModule* module);
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index f91fc9cda..c59286116 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -707,8 +707,8 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
INST(JVPDerivativeReferenceDecoration, jvpFnReference, 1, 0)
/// Used by the auto-diff pass to hold a reference to a
- /// differential getter associated with this expression.
- INST(DifferentialGetterDecoration, diffGetter, 1, 0)
+ /// differential member of a type in its associated differential type.
+ INST(JVPDerivativeMemberReferenceDecoration, derivativeMemberDecoration, 1, 0)
/// Marks a class type as a COM interface implementation, which enables
/// the witness table to be easily picked up by emit.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 33a2fbfb0..5a9c14038 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -546,6 +546,15 @@ struct IRSequentialIDDecoration : IRDecoration
IRIntegerValue getSequentialID() { return getSequentialIDOperand()->getValue(); }
};
+struct IRJVPDerivativeMarkerDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_JVPDerivativeMarkerDecoration
+ };
+ IR_LEAF_ISA(JVPDerivativeMarkerDecoration)
+};
+
struct IRJVPDerivativeReferenceDecoration : IRDecoration
{
enum
@@ -557,15 +566,15 @@ struct IRJVPDerivativeReferenceDecoration : IRDecoration
IRInst* getJVPFunc() { return getOperand(0); }
};
-struct IRDifferentialGetterDecoration : IRDecoration
+struct IRJVPDerivativeMemberReferenceDecoration : IRDecoration
{
enum
{
- kOp = kIROp_DifferentialGetterDecoration
+ kOp = kIROp_JVPDerivativeMemberReferenceDecoration
};
- IR_LEAF_ISA(DifferentialGetterDecoration)
+ IR_LEAF_ISA(JVPDerivativeMemberReferenceDecoration)
- IRInst* getGetterFunc() { return getOperand(0); }
+ IRInst* getDerivativeMemberStructKey() { return getOperand(0); }
};
// An instruction that replaces the function symbol
@@ -3192,6 +3201,11 @@ public:
addDecoration(value, kIROp_ExternCppDecoration, getStringValue(mangledName));
}
+ void addForceInlineDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_ForceInlineDecoration);
+ }
+
void addJVPDerivativeMarkerDecoration(IRInst* value)
{
addDecoration(value, kIROp_JVPDerivativeMarkerDecoration);
@@ -3202,11 +3216,6 @@ public:
addDecoration(value, kIROp_JVPDerivativeReferenceDecoration, jvpFn);
}
- void addDifferentialGetterDecoration(IRInst* value, IRInst* getterFn)
- {
- addDecoration(value, kIROp_DifferentialGetterDecoration, getterFn);
- }
-
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
{
addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);
diff --git a/source/slang/slang-lookup.cpp b/source/slang/slang-lookup.cpp
index cddf3d7ce..c574be4ea 100644
--- a/source/slang/slang-lookup.cpp
+++ b/source/slang/slang-lookup.cpp
@@ -89,6 +89,18 @@ void buildMemberDictionary(ContainerDecl* decl)
bool DeclPassesLookupMask(Decl* decl, LookupMask mask)
{
+ // Always exclude extern members from lookup result.
+ if (decl->hasModifier<ExtensionExternVarModifier>())
+ {
+ return false;
+ }
+ else if (decl->hasModifier<ExternModifier>())
+ {
+ if (as<ExtensionDecl>(decl->parentDecl))
+ {
+ return false;
+ }
+ }
// type declarations
if(auto aggTypeDecl = as<AggTypeDecl>(decl))
{
@@ -108,7 +120,7 @@ bool DeclPassesLookupMask(Decl* decl, LookupMask mask)
{
return (int(mask) & int(LookupMask::Attribute)) != 0;
}
-
+
// default behavior is to assume a value declaration
// (no overloading allowed)
@@ -942,7 +954,7 @@ static void _lookUpInScopes(
// The implicit `this`/`This` for a function-like declaration
// depends on modifiers attached to the declaration.
//
- if (funcDeclRef.getDecl()->hasModifier<HLSLStaticModifier>())
+ if (isEffectivelyStatic(funcDeclRef.getDecl()))
{
// A `static` method only has access to an implicit `This`,
// and does not have a `this` expression available.
@@ -1002,26 +1014,8 @@ LookupResult lookUp(
LookupMask mask)
{
LookupResult result;
- LookupRequestKey key;
- TypeCheckingCache* typeCheckingCache = nullptr;
- if (semantics)
- {
- typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache();
- key.base = scope;
- key.name = name;
- key.options = LookupOptions::None;
- key.mask = mask;
- if (typeCheckingCache->lookupCache.TryGetValue(key, result))
- {
- return result;
- }
- }
LookupRequest request = initLookupRequest(semantics, name, mask, LookupOptions::None, scope);
_lookUpInScopes(astBuilder, name, request, result);
- if (typeCheckingCache)
- {
- typeCheckingCache->lookupCache[key] = result;
- }
return result;
}
@@ -1033,20 +1027,9 @@ LookupResult lookUpMember(
LookupMask mask,
LookupOptions options)
{
- TypeCheckingCache* typeCheckingCache = semantics->getLinkage()->getTypeCheckingCache();
- LookupRequestKey key;
- key.base = type;
- key.name = name;
- key.options = options;
- key.mask = mask;
LookupResult result;
- if (typeCheckingCache->lookupCache.TryGetValue(key, result))
- {
- return result;
- }
LookupRequest request = initLookupRequest(semantics, name, mask, options, nullptr);
_lookUpMembersInType(astBuilder, name, type, request, result, nullptr);
- typeCheckingCache->lookupCache[key] = result;
return result;
}
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index dc6067868..1e58a456e 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3038,38 +3038,6 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
return info;
}
- LoweredValInfo visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
- {
- LoweredValInfo info = lowerSubExpr(expr->inner);
-
- IRInst* irBaseVal = nullptr;
- switch (info.flavor)
- {
- case LoweredValInfo::Flavor::Simple:
- irBaseVal = getSimpleVal(context, info);
- break;
-
- case LoweredValInfo::Flavor::Ptr:
- irBaseVal = info.val;
- break;
-
- default:
- SLANG_UNEXPECTED("Unhandled lowered value cases");
- }
-
- // If the differentiable expr has an associated getter or setter, lower it
- // and put it in a decoration.
- //
- if (expr->getterExpr != nullptr)
- {
- auto irGetter = lowerSubExpr(expr->getterExpr);
- SLANG_ASSERT(irGetter.flavor == LoweredValInfo::Flavor::Simple);
- getBuilder()->addDifferentialGetterDecoration(irBaseVal, irGetter.val);
- }
-
- return info;
- }
-
// Emit IR to denote the forward-mode derivative
// of the inner func-expr. This will be resolved
// to a concrete function during the derivative
@@ -6319,7 +6287,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// A variable declared inside of an aggregate type declaration is a member.
return true;
}
-
+ if (auto extDecl = as<ExtensionDecl>(parent))
+ {
+ if (auto declRefType = as<DeclRefType>(extDecl->targetType.type))
+ {
+ return true;
+ }
+ }
return false;
}
@@ -7108,6 +7082,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->addDecoration(inst, op, operands.getBuffer(), operands.getCount());
}
+ void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember)
+ {
+ auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val;
+ SLANG_RELEASE_ASSERT(as<IRStructKey>(key));
+ auto builder = getBuilder();
+ builder->addDecoration(inst, kIROp_JVPDerivativeMemberReferenceDecoration, key);
+ }
+
LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl)
{
// Each field declaration in the AST translates into
@@ -7120,12 +7102,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// will use the same space of keys.
auto builder = getBuilder();
- auto irFieldKey = builder->createStructKey();
- addNameHint(context, irFieldKey, fieldDecl);
+ IRInst* irFieldKey = nullptr;
+ if (auto extVarModifier = fieldDecl->findModifier<ExtensionExternVarModifier>())
+ {
+ irFieldKey = ensureDecl(context, extVarModifier->originalDecl.getDecl()).val;
+ SLANG_RELEASE_ASSERT(as<IRStructKey>(irFieldKey));
+ }
- addVarDecorations(context, irFieldKey, fieldDecl);
+ if (!irFieldKey)
+ {
+ irFieldKey = builder->createStructKey();
- addLinkageDecoration(context, irFieldKey, fieldDecl);
+ addNameHint(context, irFieldKey, fieldDecl);
+ addVarDecorations(context, irFieldKey, fieldDecl);
+ addLinkageDecoration(context, irFieldKey, fieldDecl);
+ }
if (auto semanticModifier = fieldDecl->findModifier<HLSLSimpleSemantic>())
{
@@ -7140,6 +7131,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
lowerRayPayloadAccessModifier(irFieldKey, writeModifier, kIROp_StageWriteAccessDecoration);
}
+ if (auto derivativeMemberModifier = fieldDecl->findModifier<DerivativeMemberAttribute>())
+ {
+ lowerDerivativeMemberModifier(irFieldKey, derivativeMemberModifier);
+ }
// We allow a field to be marked as a target intrinsic,
// so that we can override its mangled name in the
@@ -7815,6 +7810,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addJVPDerivativeMarkerDecoration(irFunc);
}
+ // Always force inline diff setter accessor to prevent downstream compiler from complaining
+ // fields are not fully initialized for the first `inout` parameter.
+ if (as<SetterDecl>(decl))
+ {
+ if (!decl->findModifier<ForceInlineAttribute>())
+ {
+ getBuilder()->addForceInlineDecoration(irFunc);
+ }
+ }
+
FuncDeclBaseTypeInfo info;
_lowerFuncDeclBaseTypeInfo(
subContext,