summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/core.meta.slang2
-rw-r--r--source/slang/diff.meta.slang63
-rw-r--r--source/slang/slang-ast-base.h8
-rw-r--r--source/slang/slang-ast-builder.cpp19
-rw-r--r--source/slang/slang-ast-builder.h9
-rw-r--r--source/slang/slang-ast-decl.h3
-rw-r--r--source/slang/slang-ast-expr.h6
-rw-r--r--source/slang/slang-ast-modifier.h4
-rw-r--r--source/slang/slang-ast-support-types.h12
-rw-r--r--source/slang/slang-ast-synthesis.cpp175
-rw-r--r--source/slang/slang-ast-synthesis.h147
-rw-r--r--source/slang/slang-ast-type.cpp57
-rw-r--r--source/slang/slang-ast-type.h14
-rw-r--r--source/slang/slang-ast-val.cpp35
-rw-r--r--source/slang/slang-ast-val.h18
-rw-r--r--source/slang/slang-check-conformance.cpp42
-rw-r--r--source/slang/slang-check-conversion.cpp3
-rw-r--r--source/slang/slang-check-decl.cpp438
-rw-r--r--source/slang/slang-check-expr.cpp114
-rw-r--r--source/slang/slang-check-impl.h20
-rw-r--r--source/slang/slang-check-overload.cpp6
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp72
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir.cpp14
-rw-r--r--source/slang/slang-lower-to-ir.cpp35
-rw-r--r--source/slang/slang-parser.cpp2
-rw-r--r--source/slang/slang-syntax.cpp67
-rw-r--r--source/slang/slang.natvis80
28 files changed, 1175 insertions, 294 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 75bc65562..5df9d01fe 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2434,7 +2434,7 @@ __generic<T : IComparable>
[__unsafeForceInlineEarly]
bool operator >=(T v0, T v1)
{
- return v1.lessThanOrEquals(v1);
+ return v1.lessThan(v1);
}
__generic<T : IComparable>
[__unsafeForceInlineEarly]
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index ad3dfe77c..674531048 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -128,13 +128,37 @@ extension vector<float, 4> : IDifferentiable
}
}
+__magic_type(DifferentialBottomType)
+__intrinsic_type($(kIROp_DifferentialBottomType))
+struct __DifferentialBottom : IDifferentiable
+{
+ typedef __DifferentialBottom Differential;
+
+ __intrinsic_op($(kIROp_DifferentialBottomValue))
+ static __DifferentialBottom dzero();
+
+ [__unsafeForceInlineEarly]
+ static __DifferentialBottom dadd(Differential a, Differential b)
+ {
+ return dzero();
+ }
+
+ [__unsafeForceInlineEarly]
+ static __DifferentialBottom dmul(This a, Differential b)
+ {
+ return dzero();
+ }
+}
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
__magic_type(DifferentialPairType)
__intrinsic_type($(kIROp_DifferentialPairType))
-struct __DifferentialPair
+struct DifferentialPair : IDifferentiable
{
+ typedef DifferentialPair<T.Differential> Differential;
+ typedef T.Differential DifferentialElementType;
__intrinsic_op($(kIROp_MakeDifferentialPair))
__init(T _primal, T.Differential _differential);
@@ -154,6 +178,31 @@ struct __DifferentialPair
{
return p();
}
+
+ [__unsafeForceInlineEarly]
+ static Differential dzero()
+ {
+ return Differential(T.dzero(), Differential.DifferentialElementType.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dadd(Differential a, Differential b)
+ {
+ return Differential(
+ T.dadd(
+ a.p(),
+ b.p()
+ ),
+ Differential.DifferentialElementType.dzero());
+ }
+
+ [__unsafeForceInlineEarly]
+ static Differential dmul(This a, Differential b)
+ {
+ return Differential(
+ T.dmul(a.p(), b.p()),
+ Differential.DifferentialElementType.dzero());
+ }
};
typealias IDFloat = IFloat & IDifferentiable;
@@ -171,9 +220,9 @@ namespace dstd
T exp(T x);
__generic<T : IDFloat>
- __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx)
+ DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
{
- return __DifferentialPair<T>(
+ return DifferentialPair<T>(
exp(dpx.p()),
T.dmul(exp(dpx.p()), dpx.d()));
}
@@ -189,9 +238,9 @@ namespace dstd
T sin(T x);
__generic<T : IDFloat>
- __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx)
+ DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
{
- return __DifferentialPair<T>(
+ return DifferentialPair<T>(
sin(dpx.p()),
T.dmul(cos(dpx.p()), dpx.d()));
}
@@ -207,9 +256,9 @@ namespace dstd
T cos(T x);
__generic<T : IDFloat>
- __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx)
+ DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
{
- return __DifferentialPair<T>(
+ return DifferentialPair<T>(
cos(dpx.p()),
T.dmul(-sin(dpx.p()), dpx.d()));
}
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index dea02afbb..627a56152 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -254,7 +254,9 @@ private:
// The actual values of the arguments
List<Val* > args;
public:
+ List<Val*>& getArgs() { return args; }
const List<Val*>& getArgs() const { return args; }
+
// Overrides should be public so base classes can access
Substitutions* _applySubstitutionsShallowOverride(ASTBuilder* astBuilder, SubstitutionSet substSet, Substitutions* substOuter, int* ioDiff);
bool _equalsOverride(Substitutions* subst);
@@ -265,6 +267,12 @@ public:
genericDecl = decl;
}
+ GenericSubstitution(GenericDecl* decl, ArrayView<Val*> argVals)
+ {
+ genericDecl = decl;
+ args.addRange(argVals);
+ }
+
template<typename... TArgs>
GenericSubstitution(GenericDecl* decl, TArgs... inArgs)
{
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index f6c550d69..beee16f9c 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -141,6 +141,16 @@ Type* SharedASTBuilder::getNoneType()
return m_noneType;
}
+Type* SharedASTBuilder::getDifferentialBottomType()
+{
+ if (!m_diffBottomType)
+ {
+ auto diffBottomTypeDecl = findMagicDecl("DifferentialBottomType");
+ m_diffBottomType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(diffBottomTypeDecl));
+ }
+ return m_diffBottomType;
+}
+
SharedASTBuilder::~SharedASTBuilder()
{
// Release built in types..
@@ -299,13 +309,18 @@ VectorExpressionType* ASTBuilder::getVectorType(
return result;
}
-DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness)
+DifferentialPairType* ASTBuilder::getDifferentialPairType(
+ Type* valueType,
+ Witness* primalIsDifferentialWitness)
{
auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType"));
auto typeDecl = genericDecl->inner;
- auto substitutions = getOrCreate<GenericSubstitution>(genericDecl, valueType, conformanceWitness);
+ auto substitutions = getOrCreate<GenericSubstitution>(
+ genericDecl,
+ valueType,
+ primalIsDifferentialWitness);
auto declRef = DeclRef<Decl>(typeDecl, substitutions);
auto rsType = DeclRefType::create(this, declRef);
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index e4ea872a0..235bebfaa 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -35,6 +35,8 @@ public:
Type* getNullPtrType();
/// Get the NullPtr type
Type* getNoneType();
+ /// Get the DifferentialBottom type.
+ Type* getDifferentialBottomType();
const ReflectClassInfo* findClassInfo(Name* name);
SyntaxClass<NodeBase> findSyntaxClass(Name* name);
@@ -79,7 +81,7 @@ protected:
Type* m_dynamicType = nullptr;
Type* m_nullPtrType = nullptr;
Type* m_noneType = nullptr;
-
+ Type* m_diffBottomType = nullptr;
Type* m_builtinTypes[Index(BaseType::CountOf)];
Dictionary<String, Decl*> m_magicDecls;
@@ -297,6 +299,7 @@ public:
Type* getOverloadedType() { return m_sharedASTBuilder->m_overloadedType; }
Type* getErrorType() { return m_sharedASTBuilder->m_errorType; }
Type* getBottomType() { return m_sharedASTBuilder->m_bottomType; }
+ Type* getDifferentialBottomType() { return m_sharedASTBuilder->getDifferentialBottomType(); }
Type* getStringType() { return m_sharedASTBuilder->getStringType(); }
Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); }
Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); }
@@ -326,7 +329,9 @@ public:
VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount);
- DifferentialPairType* getDifferentialPairType(Type* valueType, Witness* conformanceWitness);
+ DifferentialPairType* getDifferentialPairType(
+ Type* valueType,
+ Witness* primalIsDifferentialWitness);
DeclRef<InterfaceDecl> getDifferentiableInterface();
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 87d696927..90175dd9d 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -29,6 +29,9 @@ class ContainerDecl: public Decl
List<Decl*> members;
SourceLoc closingSourceLoc;
+ // The associated scope owned by this decl.
+ Scope* ownedScope = nullptr;
+
template<typename T>
FilteredMemberList<T> getMembersOfType()
{
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index f2a72703e..ef6a05c71 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -126,6 +126,12 @@ class InitializerListExpr : public Expr
List<Expr*> args;
};
+class GetArrayLengthExpr : public Expr
+{
+ SLANG_AST_CLASS(GetArrayLengthExpr)
+ Expr* arrayExpr = nullptr;
+};
+
// A base class for expressions with arguments
class ExprWithArgsBase : public Expr
{
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 76106074f..0c1eb8d49 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1016,14 +1016,14 @@ class RequiresNVAPIAttribute : public Attribute
};
/// The `[ForwardDifferentiable]` attribute indicates that a function can be forward-differentiated.
-class ForwardDifferentiableAttribute : public Attribute
+class ForwardDifferentiableAttribute : public DifferentiableAttribute
{
SLANG_AST_CLASS(ForwardDifferentiableAttribute)
};
/// The `[ForwardDerivative(function)]` attribute specifies a custom function that should
/// be used as the derivative for the decorated function.
-class ForwardDerivativeAttribute : public Attribute
+class ForwardDerivativeAttribute : public DifferentiableAttribute
{
SLANG_AST_CLASS(ForwardDerivativeAttribute)
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 4c92810d9..d4a781846 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -384,6 +384,12 @@ namespace Slang
///
ReadyForConformances,
+ /// Any DeclRefTypes with substitutions have been fully resolved
+ /// to concrete type. E.g. `T.X` with `T=A` should resolve to `A.X`.
+ /// We need a separate pass to resolve these types because `A.X`
+ /// maybe synthesized and made available only after conformance checking.
+ TypesFullyResolved,
+
/// The declaration is fully checked.
///
/// This step includes any validation of the declaration that is
@@ -779,6 +785,12 @@ namespace Slang
void toText(StringBuilder& out) const;
};
+ // If this is a declref to an associatedtype with a ThisTypeSubsitution,
+ // try to find the concrete decl that satisfies the associatedtype requirement from the
+ // concrete type supplied by ThisTypeSubstittution.
+ Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef);
+
+
template<typename T>
struct DeclRef : DeclRefBase
{
diff --git a/source/slang/slang-ast-synthesis.cpp b/source/slang/slang-ast-synthesis.cpp
new file mode 100644
index 000000000..e5a9ff75f
--- /dev/null
+++ b/source/slang/slang-ast-synthesis.cpp
@@ -0,0 +1,175 @@
+#include "slang-ast-synthesis.h"
+
+namespace Slang
+{
+Expr* ASTSynthesizer::emitBinaryExpr(UnownedStringSlice operatorToken, Expr* left, Expr* right)
+{
+ auto infixExpr = m_builder->create<InfixExpr>();
+ infixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));;
+ infixExpr->arguments.add(left);
+ infixExpr->arguments.add(right);
+ return infixExpr;
+}
+
+Expr* ASTSynthesizer::emitPrefixExpr(UnownedStringSlice operatorToken, Expr* base)
+{
+ auto prefixExpr = m_builder->create<PrefixExpr>();
+ prefixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));;
+ prefixExpr->arguments.add(base);
+ return prefixExpr;
+}
+
+Expr* ASTSynthesizer::emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base)
+{
+ auto postfixExpr = m_builder->create<PostfixExpr>();
+ postfixExpr->functionExpr = emitVarExpr(m_namePool->getName(operatorToken));;
+ postfixExpr->arguments.add(base);
+ return postfixExpr;
+}
+
+ForStmt* ASTSynthesizer::emitFor(Expr* initVal, Expr* finalVal, VarDecl* &outIndexVar)
+{
+ auto scopeDecl = pushVarScope()->containerDecl;
+ auto stmt = m_builder->create<ForStmt>();
+ stmt->scopeDecl = (ScopeDecl*)scopeDecl;
+ auto declStmt = emitVarDeclStmt(nullptr, m_namePool->getName("S_synth_loop_index"), initVal);
+ stmt->initialStatement = declStmt;
+ outIndexVar = (VarDecl*)declStmt->decl;
+ auto predicateExpr = emitBinaryExpr(UnownedStringSlice("<"), emitVarExpr(outIndexVar), finalVal);
+ stmt->predicateExpression = predicateExpr;
+ stmt->sideEffectExpression = emitPrefixExpr(UnownedStringSlice("++"), emitVarExpr(outIndexVar));
+ getCurrentScope().m_parentSeqStmt->stmts.add(stmt);
+ return stmt;
+}
+
+Expr* ASTSynthesizer::emitVarExpr(Name* name)
+{
+ auto scope = getCurrentScope();
+ SLANG_RELEASE_ASSERT(scope.m_scope);
+ auto varExpr = m_builder->create<VarExpr>();
+ varExpr->name = name;
+ varExpr->scope = scope.m_scope;
+ return varExpr;
+}
+
+Expr* ASTSynthesizer::emitVarExpr(VarDecl* varDecl)
+{
+ auto varExpr = m_builder->create<VarExpr>();
+ varExpr->declRef = makeDeclRef(varDecl);
+ varExpr->type = varDecl->type.type;
+ return varExpr;
+}
+
+Expr* ASTSynthesizer::emitVarExpr(VarDecl* var, Type* type)
+{
+ auto expr = m_builder->create<VarExpr>();
+ expr->declRef = DeclRef<Decl>(var, nullptr);
+ expr->type.type = type;
+ expr->type.isLeftValue = true;
+ return expr;
+}
+
+Expr* ASTSynthesizer::emitVarExpr(DeclStmt* varStmt, Type* type)
+{
+ auto expr = m_builder->create<VarExpr>();
+ expr->declRef = DeclRef<Decl>(as<Decl>(varStmt->decl), nullptr);
+ expr->type.type = type;
+ expr->type.isLeftValue = true;
+ return expr;
+}
+
+Expr* ASTSynthesizer::emitIntConst(int value)
+{
+ auto expr = m_builder->create<IntegerLiteralExpr>();
+ expr->type.type = m_builder->getIntType();
+ expr->value = value;
+ return expr;
+}
+
+Expr* ASTSynthesizer::emitGetArrayLengthExpr(Expr* arrayExpr)
+{
+ auto expr = m_builder->create<GetArrayLengthExpr>();
+ expr->arrayExpr = arrayExpr;
+ expr->type = m_builder->getIntType();
+ return expr;
+}
+
+Expr* ASTSynthesizer::emitMemberExpr(Expr* arrayExpr, Name* name)
+{
+ auto rs = m_builder->create<MemberExpr>();
+ rs->baseExpression = arrayExpr;
+ rs->name = name;
+ return rs;
+}
+
+Expr* ASTSynthesizer::emitAssignExpr(Expr* left, Expr* right)
+{
+ auto rs = m_builder->create<AssignExpr>();
+ rs->left = left;
+ rs->right = right;
+ return rs;
+}
+
+Expr* ASTSynthesizer::emitInvokeExpr(Expr* callee, List<Expr*>&& args)
+{
+ auto rs = m_builder->create<InvokeExpr>();
+ rs->functionExpr = callee;
+ rs->arguments = _Move(args);
+ return rs;
+}
+
+Expr* ASTSynthesizer::emitMemberExpr(Type* type, Name* name)
+{
+ auto rs = m_builder->create<StaticMemberExpr>();
+ auto typeExpr = m_builder->create<SharedTypeExpr>();
+ auto typetype = m_builder->create<TypeType>();
+ typetype->type = type;
+ typeExpr->type = typetype;
+ rs->baseExpression = typeExpr;
+ rs->name = name;
+ return rs;
+}
+
+Expr* ASTSynthesizer::emitIndexExpr(Expr* base, Expr* index)
+{
+ auto rs = m_builder->create<IndexExpr>();
+ rs->baseExpression = base;
+ rs->indexExprs.add(index);
+ return rs;
+}
+
+ExpressionStmt* ASTSynthesizer::emitExprStmt(Expr* expr)
+{
+ auto rs = m_builder->create<ExpressionStmt>();
+ _addStmtToScope(rs);
+ rs->expression = expr;
+ return rs;
+}
+
+ReturnStmt* ASTSynthesizer::emitReturnStmt(Expr* expr)
+{
+ auto rs = m_builder->create<ReturnStmt>();
+ rs->expression = expr;
+ _addStmtToScope(rs);
+ return rs;
+}
+
+DeclStmt* ASTSynthesizer::emitVarDeclStmt(Type* type, Name* name, Expr* initVal)
+{
+ auto scope = getCurrentScope();
+ SLANG_RELEASE_ASSERT(scope.m_parentSeqStmt);
+ SLANG_RELEASE_ASSERT(scope.m_scope);
+ SLANG_RELEASE_ASSERT(scope.m_scope->containerDecl);
+ auto varDecl = m_builder->create<VarDecl>();
+ varDecl->type.type = type;
+ varDecl->nameAndLoc.name = name;
+ varDecl->initExpr = initVal;
+ varDecl->parentDecl = scope.m_scope->containerDecl;
+ varDecl->parentDecl->members.add(varDecl);
+ auto stmt = m_builder->create<DeclStmt>();
+ stmt->decl = varDecl;
+ _addStmtToScope(stmt);
+ return stmt;
+}
+
+}
diff --git a/source/slang/slang-ast-synthesis.h b/source/slang/slang-ast-synthesis.h
new file mode 100644
index 000000000..2af890d34
--- /dev/null
+++ b/source/slang/slang-ast-synthesis.h
@@ -0,0 +1,147 @@
+// slang-ast-synthesis.h
+
+#pragma once
+
+#include "slang-syntax.h"
+
+namespace Slang {
+
+struct ASTEmitScope
+{
+ ContainerDecl* m_parent = nullptr;
+ SeqStmt* m_parentSeqStmt = nullptr;
+ Scope* m_scope = nullptr;
+};
+class ASTSynthesizer
+{
+private:
+ ASTBuilder* m_builder;
+ NamePool* m_namePool;
+ List<ASTEmitScope> m_scopeStack;
+public:
+ ASTSynthesizer(ASTBuilder* builder, NamePool* namePool)
+ : m_builder(builder)
+ , m_namePool(namePool)
+ {
+ }
+
+ Scope* getScope(ContainerDecl* decl)
+ {
+ for (auto container = decl; container; container = container->parentDecl)
+ {
+ if (container->ownedScope)
+ {
+ return container->ownedScope;
+ }
+ }
+ return nullptr;
+ }
+
+ // Create a scope for `decl` and push it to scope stack
+ void pushScopeForContainer(ContainerDecl* decl)
+ {
+ if (decl->ownedScope)
+ {
+ // if decl already owns a scope, don't create a new one.
+ pushContainerScope(decl);
+ return;
+ }
+
+ auto parentScope = getScope(decl);
+ decl->ownedScope = m_builder->create<Scope>();
+ decl->ownedScope->parent = parentScope;
+ pushContainerScope(decl);
+ }
+
+ // Push `decl` and its associated scope to scope stack
+ void pushContainerScope(ContainerDecl* decl)
+ {
+ ASTEmitScope scope = getCurrentScope();
+ scope.m_parent = decl;
+ scope.m_scope = getScope(decl);
+ m_scopeStack.add(scope);
+ }
+
+ Scope* pushVarScope()
+ {
+ ASTEmitScope scope = getCurrentScope();
+ auto scopeDecl = m_builder->create<ScopeDecl>();
+ auto newScope = m_builder->create<Scope>();
+ scopeDecl->parentDecl = scope.m_parent;
+ if (scope.m_parent)
+ scope.m_parent->members.add(scopeDecl);
+ newScope->parent = scope.m_scope;
+ newScope->containerDecl = scopeDecl;
+ scope.m_scope = newScope;
+ m_scopeStack.add(scope);
+ return newScope;
+ }
+
+ void _addStmtToScope(Stmt* stmt)
+ {
+ auto scope = getCurrentScope();
+ if (scope.m_parentSeqStmt)
+ {
+ scope.m_parentSeqStmt->stmts.add(stmt);
+ }
+ }
+
+ SeqStmt* pushSeqStmtScope()
+ {
+ ASTEmitScope scope = getCurrentScope();
+ scope.m_parentSeqStmt = m_builder->create<SeqStmt>();
+ m_scopeStack.add(scope);
+ return scope.m_parentSeqStmt;
+ }
+
+ void popScope()
+ {
+ m_scopeStack.removeLast();
+ }
+
+ ASTEmitScope getCurrentScope()
+ {
+ if (m_scopeStack.getCount())
+ return m_scopeStack.getLast();
+ return ASTEmitScope();
+ }
+
+ ForStmt* emitFor(Expr* initVal, Expr* finalVal, VarDecl*& outIndexVar);
+
+ Expr* emitBinaryExpr(UnownedStringSlice operatorToken, Expr* left, Expr* right);
+
+ Expr* emitPrefixExpr(UnownedStringSlice operatorToken, Expr* base);
+
+ Expr* emitPostfixExpr(UnownedStringSlice operatorToken, Expr* base);
+
+ Expr* emitVarExpr(Name* name);
+ Expr* emitVarExpr(VarDecl* var);
+ Expr* emitVarExpr(VarDecl* var, Type* type);
+ Expr* emitVarExpr(DeclStmt* varStmt, Type* type);
+
+ Expr* emitIntConst(int value);
+
+ Expr* emitGetArrayLengthExpr(Expr* arrayExpr);
+
+ Expr* emitMemberExpr(Expr* base, Name* name);
+ Expr* emitMemberExpr(Type* base, Name* name);
+
+ Expr* emitIndexExpr(Expr* base, Expr* index);
+
+ Expr* emitAssignExpr(Expr* left, Expr* right);
+ ExpressionStmt* emitAssignStmt(Expr* left, Expr* right)
+ {
+ return emitExprStmt(emitAssignExpr(left, right));
+ }
+
+ Expr* emitInvokeExpr(Expr* callee, List<Expr*>&& args);
+
+ DeclStmt* emitVarDeclStmt(Type* type, Name* name = nullptr, Expr* initVal = nullptr);
+
+ ExpressionStmt* emitExprStmt(Expr* expr);
+
+ ReturnStmt* emitReturnStmt(Expr* expr);
+
+};
+
+} // namespace Slang
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index 480589af4..a869c95a7 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -169,6 +169,27 @@ Val* BottomType::_substituteImplOverride(
HashCode BottomType::_getHashCodeOverride() { return HashCode(size_t(this)); }
+// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DifferentialBottomType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
+
+void DifferentialBottomType::_toTextOverride(StringBuilder& out) { out << toSlice("diff_bottom"); }
+
+bool DifferentialBottomType::_equalsImplOverride(Type* type)
+{
+ if (auto bottomType = as<DifferentialBottomType>(type))
+ return true;
+ return false;
+}
+
+Type* DifferentialBottomType::_createCanonicalTypeOverride() { return this; }
+
+Val* DifferentialBottomType::_substituteImplOverride(
+ ASTBuilder* /* astBuilder */, SubstitutionSet /*subst*/, int* /*ioDiff*/)
+{
+ return this;
+}
+
+HashCode DifferentialBottomType::_getHashCodeOverride() { return HashCode(size_t(this)); }
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! DeclRefType !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
void DeclRefType::_toTextOverride(StringBuilder& out)
@@ -193,6 +214,7 @@ bool DeclRefType::_equalsImplOverride(Type * type)
Type* DeclRefType::_createCanonicalTypeOverride()
{
// A declaration reference is already canonical
+ declRef.substitute(this->getASTBuilder(), this);
return this;
}
@@ -223,39 +245,8 @@ Val* DeclRefType::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSe
// the outer interface, then try to replace the type with the
// actual value of the associated type for the given implementation.
//
- if (auto substAssocTypeDecl = as<AssocTypeDecl>(substDeclRef.decl))
- {
- for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer)
- {
- auto thisSubst = as<ThisTypeSubstitution>(s);
- if (!thisSubst)
- continue;
-
- if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl))
- {
- if (thisSubst->interfaceDecl == interfaceDecl)
- {
- // We need to look up the declaration that satisfies
- // the requirement named by the associated type.
- Decl* requirementKey = substAssocTypeDecl;
- RequirementWitness requirementWitness = tryLookUpRequirementWitness(astBuilder, thisSubst->witness, requirementKey);
- switch (requirementWitness.getFlavor())
- {
- default:
- // No usable value was found, so there is nothing we can do.
- break;
-
- case RequirementWitness::Flavor::val:
- {
- auto satisfyingVal = requirementWitness.getVal();
- return satisfyingVal;
- }
- break;
- }
- }
- }
- }
- }
+ if (auto satisfyingVal = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(astBuilder, substDeclRef))
+ return satisfyingVal;
// Re-construct the type in case we are using a specialized sub-class
return DeclRefType::create(astBuilder, substDeclRef);
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index 895b64f35..c7ce21cb0 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -86,6 +86,20 @@ protected:
{}
};
+// The bottom/empty type as a result of Differentiating a Differential.
+class DifferentialBottomType : public DeclRefType
+{
+ SLANG_AST_CLASS(DifferentialBottomType)
+
+ // Overrides should be public so base classes can access
+ void _toTextOverride(StringBuilder& out);
+ Type* _createCanonicalTypeOverride();
+ bool _equalsImplOverride(Type* type);
+ HashCode _getHashCodeOverride();
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+};
+
+
// Base class for types that can be used in arithmetic expressions
class ArithmeticExpressionType : public DeclRefType
{
diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp
index a8ceaa716..87e89ef18 100644
--- a/source/slang/slang-ast-val.cpp
+++ b/source/slang/slang-ast-val.cpp
@@ -618,6 +618,41 @@ Val* TaggedUnionSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder,
return substWitness;
}
+bool DifferentialBottomSubtypeWitness::_equalsValOverride(Val* val)
+{
+ auto otherDiffBottomWitness = as<DifferentialBottomSubtypeWitness>(val);
+ if (!otherDiffBottomWitness)
+ return false;
+
+ return otherDiffBottomWitness->sub && otherDiffBottomWitness->sub->equals(sub);
+}
+
+void DifferentialBottomSubtypeWitness::_toTextOverride(StringBuilder& out)
+{
+ out << "DifferentialBottomSubtypeWitness(" << sub << ")";
+}
+
+HashCode DifferentialBottomSubtypeWitness::_getHashCodeOverride()
+{
+ return combineHash(3892, sub->getHashCode());
+}
+
+Val* DifferentialBottomSubtypeWitness::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff)
+{
+ int diff = 0;
+
+ auto substSub = as<Type>(sub->substituteImpl(astBuilder, subst, &diff));
+ auto substSup = as<Type>(sup->substituteImpl(astBuilder, subst, &diff));
+ if (!diff)
+ return this;
+
+ *ioDiff += diff;
+
+ DifferentialBottomSubtypeWitness* substWitness =
+ astBuilder->create<DifferentialBottomSubtypeWitness>(substSub, substSup);
+ return substWitness;
+}
+
bool ConjunctionSubtypeWitness::_equalsValOverride(Val* val)
{
if (auto other = as<ConjunctionSubtypeWitness>(val))
diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h
index 5f72e58c8..b52984f8b 100644
--- a/source/slang/slang-ast-val.h
+++ b/source/slang/slang-ast-val.h
@@ -399,6 +399,24 @@ class DynamicSubtypeWitness : public SubtypeWitness
SLANG_AST_CLASS(DynamicSubtypeWitness)
};
+ /// A witness of the fact that any type can be viewed as a subtype of DifferentialBottom.
+class DifferentialBottomSubtypeWitness : public SubtypeWitness
+{
+ SLANG_AST_CLASS(DifferentialBottomSubtypeWitness)
+
+ DifferentialBottomSubtypeWitness(Type* inSub, Type* inSup)
+ {
+ sub = inSub;
+ sup = inSup;
+ }
+
+ // Overrides should be public so base classes can access
+ bool _equalsValOverride(Val* val);
+ void _toTextOverride(StringBuilder& out);
+ HashCode _getHashCodeOverride();
+ Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff);
+};
+
/// A witness that `T : L & R` because `T : L` and `T : R`
class ConjunctionSubtypeWitness : public SubtypeWitness
{
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index 2c9977082..eb072e9dd 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -33,7 +33,6 @@ namespace Slang
else
return conjunction->rightWitness;
}
-
ExtractFromConjunctionSubtypeWitness* simplExtractFromConjunction = builder->create<ExtractFromConjunctionSubtypeWitness>();
simplExtractFromConjunction->sub = extractFromConjunction->sub;
simplExtractFromConjunction->sup = extractFromConjunction->sup;
@@ -145,7 +144,7 @@ namespace Slang
m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
bb->sub, bb->sup, bb->declRef.decl, bb->declRef.substitutions.substitutions);
- TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(subType, bb->sup, declaredWitness);
+ TransitiveSubtypeWitness* transitiveWitness = m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>();
transitiveWitness->sub = subType;
transitiveWitness->sup = bb->sup;
transitiveWitness->midToSup = declaredWitness;
@@ -379,6 +378,45 @@ namespace Slang
}
}
}
+
+ // If a generic type parameter does not declare itself to conform to `IDifferentiable`,
+ // we treat it as a subtype of `DifferentialBottom` to make it conform to `IDifferentiable`.
+ // Note: we only consider this option for `originalSubType` so a type that implements `IDifferential` but
+ // inherits from some other non differentiable types don't get to inherit `DifferentialBottom`.
+ if (m_astBuilder->isDifferentiableInterfaceAvailable() &&
+ subType == originalSubType &&
+ superTypeDeclRef.getDecl() == m_astBuilder->getDifferentiableInterface())
+ {
+ if (as<GenericTypeParamDecl>(declRefType->declRef.getDecl()) ||
+ as<AssocTypeDecl>(declRefType->declRef.getDecl()))
+ {
+ auto sup = DeclRefType::create(m_astBuilder, superTypeDeclRef);
+ auto differentialBottomType = as<DeclRefType>(m_astBuilder->getDifferentialBottomType());
+ auto container = differentialBottomType->declRef.as<ContainerDecl>().getDecl();
+ SLANG_RELEASE_ASSERT(container);
+ auto inheritanceDecl = container->getMembersOfType<InheritanceDecl>().getFirst();
+ auto witnessDifferentialBottomIsIDifferentiable =
+ m_astBuilder->getOrCreate<DeclaredSubtypeWitness>(
+ m_astBuilder->getDifferentialBottomType(),
+ sup,
+ inheritanceDecl,
+ nullptr);
+
+ auto witnessSubIsDifferentialBottom =
+ m_astBuilder->getOrCreate<DifferentialBottomSubtypeWitness>(
+ subType, differentialBottomType);
+
+ TransitiveSubtypeWitness* transitiveWitness =
+ m_astBuilder->getOrCreateWithDefaultCtor<TransitiveSubtypeWitness>(
+ witnessSubIsDifferentialBottom, witnessDifferentialBottomIsIDifferentiable);
+ transitiveWitness->sub = subType;
+ transitiveWitness->sup = sup;
+ transitiveWitness->midToSup = witnessDifferentialBottomIsIDifferentiable;
+ transitiveWitness->subToMid = witnessSubIsDifferentialBottom;
+ *outWitness = transitiveWitness;
+ return true;
+ }
+ }
}
else if (auto extractExistentialType = as<ExtractExistentialType>(subType))
{
diff --git a/source/slang/slang-check-conversion.cpp b/source/slang/slang-check-conversion.cpp
index e56d63f91..5e84e170b 100644
--- a/source/slang/slang-check-conversion.cpp
+++ b/source/slang/slang-check-conversion.cpp
@@ -1169,9 +1169,6 @@ namespace Slang
fromExpr);
}
- // If we coerced to a differentiable type, log it.
- maybeRegisterDifferentiableType(m_astBuilder, expr->type);
-
return expr;
}
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 457ae229b..f60fbcc2c 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -11,9 +11,8 @@
// and when things get checked.
#include "slang-lookup.h"
-
#include "slang-syntax.h"
-
+#include "slang-ast-synthesis.h"
#include <limits>
namespace Slang
@@ -166,6 +165,65 @@ namespace Slang
void visitExtensionDecl(ExtensionDecl* decl);
};
+ struct SemanticsDeclTypeResolutionVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclTypeResolutionVisitor>
+ {
+ SemanticsDeclTypeResolutionVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ Val* resolveVal(Val* val);
+ Type* resolveType(Type* type)
+ {
+ return (Type*)resolveVal(type);
+ }
+
+ void visitTypeExp(TypeExp& exp)
+ {
+ exp.type = resolveType(exp.type);
+ }
+
+ void visitVarDeclBase(VarDeclBase* varDecl)
+ {
+ visitTypeExp(varDecl->type);
+ }
+
+ void visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl)
+ {
+ visitTypeExp(decl->sup);
+ }
+
+ void visitTypeDefDecl(TypeDefDecl* decl)
+ {
+ visitTypeExp(decl->type);
+ }
+
+ void visitGenericTypeParamDecl(GenericTypeParamDecl* paramDecl)
+ {
+ visitTypeExp(paramDecl->initType);
+ }
+
+ void visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
+ {
+ visitTypeExp(inheritanceDecl->base);
+ }
+
+ void visitCallableDecl(CallableDecl* decl)
+ {
+ visitTypeExp(decl->returnType);
+ visitTypeExp(decl->errorType);
+ }
+
+ void visitPropertyDecl(PropertyDecl* decl)
+ {
+ visitTypeExp(decl->type);
+ }
+ };
+
struct SemanticsDeclBodyVisitor
: public SemanticsDeclVisitorBase
, public DeclVisitor<SemanticsDeclBodyVisitor>
@@ -1363,27 +1421,30 @@ namespace Slang
bool SemanticsVisitor::trySynthesizeDifferentialAssociatedTypeRequirementWitness(
ConformanceCheckingContext* context,
- DeclRef<Decl> requirementDeclRef,
+ DeclRef<AssocTypeDecl> requirementDeclRef,
RefPtr<WitnessTable> witnessTable)
{
- // We currently can't handle generic types.
- if (GetOuterGeneric(context->parentDecl) != nullptr)
- {
- return false;
- }
-
+ ASTSynthesizer synth(m_astBuilder, getNamePool());
Decl* existingDecl = nullptr;
AggTypeDecl* aggTypeDecl = nullptr;
if (context->parentDecl->getMemberDictionary().TryGetValue(requirementDeclRef.getName(), existingDecl))
{
- aggTypeDecl = as<AggTypeDecl>(existingDecl);
- SLANG_RELEASE_ASSERT(aggTypeDecl);
-
// Remove the `ToBeSynthesizedModifier`.
- if (as<ToBeSynthesizedModifier>(aggTypeDecl->modifiers.first))
+ if (as<ToBeSynthesizedModifier>(existingDecl->modifiers.first))
{
- aggTypeDecl->modifiers.first = aggTypeDecl->modifiers.first->next;
+ existingDecl->modifiers.first = existingDecl->modifiers.first->next;
}
+ else
+ {
+ // The user has defined an associatedtype explicitly but that we reach here because
+ // that type failed to satisfy the `IDifferential` requirement.
+ // We stop the synthesis and let the follow-up logic to report a diagnostic.
+ return false;
+ }
+
+ aggTypeDecl = as<AggTypeDecl>(existingDecl);
+ SLANG_RELEASE_ASSERT(aggTypeDecl);
+ synth.pushContainerScope(aggTypeDecl);
}
else
{
@@ -1393,15 +1454,12 @@ namespace Slang
aggTypeDecl->nameAndLoc.name = requirementDeclRef.getName();
aggTypeDecl->loc = context->parentDecl->nameAndLoc.loc;
context->parentDecl->invalidateMemberDictionary();
+ synth.pushScopeForContainer(aggTypeDecl);
}
- // TODO: if we want to make the synthesized type itself to be differentiable,
- // add an inheritance decl here. Need to be careful to avoid infinite recursion
- // trying to synthesize the higher order differential types.
-
// Helper function to add a `diffType` field into the synthesized type for the original
// `member`.
- auto differentialType = GetTypeForDeclRef(makeDeclRef(aggTypeDecl), context->parentDecl->loc);
+ auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl));
auto addDiffMember = [&](Decl* member, Type* diffMemberType)
{
// If the field is differentiable, add a corresponding field in the associated Differential type.
@@ -1452,12 +1510,35 @@ namespace Slang
addDiffMember(member, diffType);
}
- // In the future when the Differential type itself needs to conform to some interface,
- // this is the place to synthesize requirements for them.
addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
- auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, nullptr);
- witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
- return true;
+
+ // If `This` is nested inside a generic, we need to form a complete declref type to the
+ // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution
+ // from requirementDeclRef to get the generic substitution for outer generic parameters, and
+ // apply it to the newly synthesized decl.
+ SubstitutionSet substSet;
+ if (auto thisTypeSusbt = findThisTypeSubstitution(
+ requirementDeclRef.substitutions,
+ as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
+ {
+ if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ {
+ substSet = declRefType->declRef.substitutions;
+ }
+ }
+
+ auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet);
+
+ if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable))
+ {
+ witnessTable->add(requirementDeclRef.getDecl(), RequirementWitness(satisfyingType));
+ return true;
+ }
+
+ // Note: the call to `doesTypeSatisfyAssociatedTypeConstraintRequirement` should always succeed.
+ // If not, there is something wrong with the code synthesis logic. For now we just return false
+ // instead of crashing so the user can work around the issues.
+ return false;
}
void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*)
@@ -2242,22 +2323,8 @@ namespace Slang
witnessTable);
}
- bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement(
- Type* satisfyingType,
- DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
- RefPtr<WitnessTable> witnessTable)
+ bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeConstraintRequirement(Type* satisfyingType, DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef, RefPtr<WitnessTable> witnessTable)
{
- if (auto declRefType = as<DeclRefType>(satisfyingType))
- {
- // If we are seeing a placeholder that awaits synthesis, return false now to trigger
- // auto synthesis.
- if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>())
- return false;
- }
- // We need to confirm that the chosen type `satisfyingType`,
- // meets all the constraints placed on the associated type
- // requirement `requiredAssociatedTypeDeclRef`.
- //
// We will enumerate the type constraints placed on the
// associated type and see if they can be satisfied.
//
@@ -2269,7 +2336,7 @@ namespace Slang
// Perform a search for a witness to the subtype relationship.
auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType);
- if(witness)
+ if (witness)
{
// If a subtype witness was found, then the conformance
// appears to hold, and we can satisfy that requirement.
@@ -2282,6 +2349,30 @@ namespace Slang
conformance = false;
}
}
+ return conformance;
+ }
+
+ bool SemanticsVisitor::doesTypeSatisfyAssociatedTypeRequirement(
+ Type* satisfyingType,
+ DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
+ RefPtr<WitnessTable> witnessTable)
+ {
+ if (auto declRefType = as<DeclRefType>(satisfyingType))
+ {
+ // If we are seeing a placeholder that awaits synthesis, return false now to trigger
+ // auto synthesis.
+ if (declRefType->declRef.getDecl()->hasModifier<ToBeSynthesizedModifier>())
+ return false;
+ }
+ // We need to confirm that the chosen type `satisfyingType`,
+ // meets all the constraints placed on the associated type
+ // requirement `requiredAssociatedTypeDeclRef`.
+ //
+ // We will enumerate the type constraints placed on the
+ // associated type and see if they can be satisfied.
+ //
+ bool conformance = doesTypeSatisfyAssociatedTypeConstraintRequirement(
+ satisfyingType, requiredAssociatedTypeDeclRef, witnessTable);
// TODO: if any conformance check failed, we should probably include
// that in an error message produced about not satisfying the requirement.
@@ -3122,12 +3213,43 @@ namespace Slang
return false;
}
+ Stmt* _synthesizeMemberAssignMemberHelper(ASTSynthesizer& synth, Name* funcName, Type* leftType, Expr* leftValue, List<Expr*>&& args, int nestingLevel = 0)
+ {
+ if (nestingLevel > 16)
+ return nullptr;
+
+ // If field type is an array, assign each element individually.
+ if (auto arrayType = as<ArrayExpressionType>(leftType))
+ {
+ VarDecl* indexVar = nullptr;
+ auto forStmt = synth.emitFor(synth.emitIntConst(0), synth.emitGetArrayLengthExpr(leftValue), indexVar);
+ auto innerLeft = synth.emitIndexExpr(leftValue, synth.emitVarExpr(indexVar));
+ for (auto& arg : args)
+ {
+ arg = synth.emitIndexExpr(arg, synth.emitVarExpr(indexVar));
+ }
+ auto assignStmt = _synthesizeMemberAssignMemberHelper(synth, funcName, arrayType->baseType, innerLeft, _Move(args), nestingLevel + 1);
+ synth.popScope();
+ if (!assignStmt)
+ return nullptr;
+ forStmt->statement = assignStmt;
+ return forStmt;
+ }
+
+ auto callee = synth.emitMemberExpr(leftType, funcName);
+ return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args)));
+ }
+
bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<Decl> requirementDeclRef,
RefPtr<WitnessTable> witnessTable)
{
- // This method implements a general code synthesis pattern.
+ // We support two cases of synthesis here.
+ // Case 1 is that there the associated Differential type is defined to be `DifferentialBottom`.
+ // In this case we just trivially return `DifferentialBottom` in all synthesized methods.
+ // Case 2 is that the `Differential` type contains members corresponding to each primal member.
+ // We will apply a general code synthesis pattern to reflect that structure.
// For requirement of the form:
// ```
// static TResult requiredMethod(TParam1 p0, TParam2 p1, ...)
@@ -3145,104 +3267,123 @@ namespace Slang
// return result;
// }
// ```
+
+ // First we need to make sure the associated `Differential` type requirement is satisfied.
+ bool hasDifferentialAssocType = false;
+ for (auto existingEntry : witnessTable->requirementList)
+ {
+ if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>())
+ {
+ if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType &&
+ existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none)
+ {
+ hasDifferentialAssocType = true;
+ }
+ }
+ }
+ if (!hasDifferentialAssocType)
+ return false;
+
+ ASTSynthesizer synth(m_astBuilder, getNamePool());
List<Expr*> synArgs;
ThisExpr* synThis = nullptr;
auto synFunc = synthesizeMethodSignatureForRequirementWitness(
context, requirementDeclRef.as<FuncDecl>(), synArgs, synThis);
-
+ synFunc->parentDecl = context->parentDecl;
+ synth.pushContainerScope(synFunc);
auto blockStmt = m_astBuilder->create<BlockStmt>();
synFunc->body = blockStmt;
- auto seqStmt = m_astBuilder->create<SeqStmt>();
+ auto seqStmt = synth.pushSeqStmtScope();
blockStmt->body = seqStmt;
- // Create a variable for return value.
- auto scopeDecl = m_astBuilder->create<ScopeDecl>();
- synFunc->members.add(scopeDecl);
- scopeDecl->parentDecl = synFunc;
- auto varStmt = m_astBuilder->create<DeclStmt>();
- seqStmt->stmts.add(varStmt);
-
- auto returnVar = m_astBuilder->create<VarDecl>();
- returnVar->parentDecl = scopeDecl;
- scopeDecl->members.add(returnVar);
-
- returnVar->type.type = synFunc->returnType.type;
- returnVar->nameAndLoc.name = getName("result");
- varStmt->decl = returnVar;
- auto resultVarExpr = m_astBuilder->create<VarExpr>();
- resultVarExpr->declRef = makeDeclRef(returnVar);
- resultVarExpr->type.type = synFunc->returnType.type;
- resultVarExpr->type.isLeftValue = true;
-
- for (auto member : context->parentDecl->members)
- {
- auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
- if (!derivativeAttr)
- continue;
- auto varMember = as<VarDeclBase>(member);
- if (!varMember)
- continue;
- ensureDecl(varMember, DeclCheckState::ReadyForReference);
- auto memberType = varMember->getType();
- auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType);
- if (!diffMemberType)
- continue;
+ if (synFunc->returnType.type->equals(m_astBuilder->getDifferentialBottomType()))
+ {
+ // Trivial case, the `Differential` type is `DifferentialBottom`.
+ // We will just return `DifferentialBottom.dzero()`.
+ auto resultExpr = m_astBuilder->create<InvokeExpr>();
+ auto dzeroMember = m_astBuilder->create<StaticMemberExpr>();
+ auto base = m_astBuilder->create<SharedTypeExpr>();
+ auto typetype = m_astBuilder->create<TypeType>();
+ typetype->type = m_astBuilder->getDifferentialBottomType();
+ base->type.type = typetype;
+ dzeroMember->baseExpression = base;
+ dzeroMember->name = getName("dzero");
+ resultExpr->functionExpr = dzeroMember;
+ auto synReturn = m_astBuilder->create<ReturnStmt>();
+ synReturn->expression = resultExpr;
+ seqStmt->stmts.add(synReturn);
+ }
+ else
+ {
+ // The general case.
+ // Create a variable for return value.
+ synth.pushVarScope();
+ auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result"));
+ auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type);
- // Construct reference exprs to the member's corresponding fields in each parameter.
- List<Expr*> paramFields;
- int paramIndex = 0;
- for (auto arg : synArgs)
- {
- auto memberExpr = m_astBuilder->create<MemberExpr>();
- memberExpr->baseExpression = arg;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
- // Differential type.
- memberExpr->name = varMember->getName();
- paramFields.add(memberExpr);
- paramIndex++;
- }
-
- // Invoke the method for the field.
- auto callee = m_astBuilder->create<StaticMemberExpr>();
- auto baseSharedType = m_astBuilder->create<SharedTypeExpr>();
- auto baseSharedTypeType = m_astBuilder->create<TypeType>();
- baseSharedTypeType->type = memberType;
- baseSharedType->type = baseSharedTypeType;
- baseSharedType->base.type = memberType;
- callee->baseExpression = baseSharedType;
- callee->name = requirementDeclRef.getName();
- callee->loc = synFunc->loc;
- auto invokeExpr = m_astBuilder->create<InvokeExpr>();
- invokeExpr->functionExpr = callee;
- invokeExpr->arguments = _Move(paramFields);
-
- // Assign the value to resultVar.
- auto leftVal = m_astBuilder->create<MemberExpr>();
- leftVal->baseExpression = resultVarExpr;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
- // is Differential type.
- leftVal->name = varMember->getName();
-
- auto assignExpr = m_astBuilder->create<AssignExpr>();
- assignExpr->left = leftVal;
- assignExpr->right = invokeExpr;
- auto assignStmt = m_astBuilder->create<ExpressionStmt>();
- assignStmt->expression = assignExpr;
- seqStmt->stmts.add(assignStmt);
- }
-
- // TODO: synthesize assignments for inherited members here.
-
- auto synReturn = m_astBuilder->create<ReturnStmt>();
- synReturn->expression = resultVarExpr;
- seqStmt->stmts.add(synReturn);
+ for (auto member : context->parentDecl->members)
+ {
+ auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
+ if (!derivativeAttr)
+ continue;
+ auto varMember = as<VarDeclBase>(member);
+ if (!varMember)
+ continue;
+ ensureDecl(varMember, DeclCheckState::ReadyForReference);
+ auto memberType = varMember->getType();
+ auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType);
+ if (!diffMemberType)
+ continue;
- synFunc->parentDecl = context->parentDecl;
+ // Construct reference exprs to the member's corresponding fields in each parameter.
+ List<Expr*> paramFields;
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ paramIndex++;
+ }
+
+ // Invoke the method for the field and assign the value to resultVar.
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
+ // is Differential type.
+ auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName());
+ if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields)))
+ return false;
+ }
+
+ // TODO: synthesize assignments for inherited members here.
+
+ auto synReturn = m_astBuilder->create<ReturnStmt>();
+ synReturn->expression = resultVarExpr;
+ seqStmt->stmts.add(synReturn);
+ }
+
context->parentDecl->members.add(synFunc);
context->parentDecl->invalidateMemberDictionary();
addModifier(synFunc, m_astBuilder->create<SynthesizedModifier>());
- witnessTable->add(requirementDeclRef, RequirementWitness(makeDeclRef(synFunc)));
+ // If `This` is nested inside a generic, we need to form a complete declref type to the
+ // newly synthesized method here in order to fill into the witness table.
+ // This can be done by obtaining ThisTypeSubstitution from requirementDeclRef to get the
+ // generic substitution for outer generic parameters, and apply it here.
+ SubstitutionSet substSet;
+ if (auto thisTypeSusbt = findThisTypeSubstitution(
+ requirementDeclRef.substitutions,
+ as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
+ {
+ if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ {
+ substSet = declRefType->declRef.substitutions;
+ }
+ }
+
+ witnessTable->add(requirementDeclRef, RequirementWitness(DeclRef<Decl>(synFunc, substSet)));
return true;
}
@@ -3801,7 +3942,10 @@ namespace Slang
// be required to implement all interface requirements,
// just with `abstract` methods that replicate things?
// (That's what C# does).
- for (auto inheritanceDecl : decl->getMembersOfType<InheritanceDecl>())
+
+ // Make a copy of inhertanceDecls firstsince `checkConformance` may modify decl->members.
+ auto inheritanceDecls = decl->getMembersOfType<InheritanceDecl>().toList();
+ for (auto inheritanceDecl : inheritanceDecls)
{
checkConformance(type, inheritanceDecl, decl);
}
@@ -5230,7 +5374,7 @@ namespace Slang
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
- if (decl->findModifier<ForwardDifferentiableAttribute>())
+ if (decl->findModifier<DifferentiableAttribute>())
{
this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
}
@@ -6274,6 +6418,10 @@ namespace Slang
SemanticsDeclConformancesVisitor(shared).dispatch(decl);
break;
+ case DeclCheckState::TypesFullyResolved:
+ SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl);
+ break;
+
case DeclCheckState::Checked:
SemanticsDeclBodyVisitor(shared).dispatch(decl);
break;
@@ -6325,4 +6473,40 @@ namespace Slang
return result;
}
+ Val* SemanticsDeclTypeResolutionVisitor::resolveVal(Val* val)
+ {
+ if (auto declRefType = as<DeclRefType>(val))
+ {
+ if (auto concreteType = _tryLookupConcreteAssociatedTypeFromThisTypeSubst(m_astBuilder, declRefType->declRef))
+ return as<Type>(concreteType);
+ for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst=subst->outer)
+ {
+ if (auto genericSubst = as<GenericSubstitution>(subst))
+ {
+ ShortList<Val*> newArgs;
+ for (auto& arg : genericSubst->getArgs())
+ {
+ arg = resolveVal(arg);
+ SLANG_RELEASE_ASSERT(arg);
+ }
+ }
+ }
+ }
+ else if (auto subtypeWitness = as<SubtypeWitness>(val))
+ {
+ auto sub = as<Type>(resolveVal(subtypeWitness->sub));
+ auto sup = as<Type>(resolveVal(subtypeWitness->sup));
+ if (sub && sup)
+ {
+ if (sub != subtypeWitness->sub || sup != subtypeWitness->sup)
+ {
+ auto newVal = tryGetSubtypeWitness(as<Type>(sub), as<Type>(sup));
+ if (newVal)
+ val = newVal;
+ }
+ }
+ }
+ return val;
+ }
+
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index fe37f5099..251849ede 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1410,6 +1410,19 @@ namespace Slang
return m_astBuilder->getOrCreate<ConstantIntVal>(m_astBuilder->getBoolType(), value);
}
+ if (auto arrayLengthExpr = expr.as<GetArrayLengthExpr>())
+ {
+ if (arrayLengthExpr.getExpr()->arrayExpr && arrayLengthExpr.getExpr()->arrayExpr->type)
+ {
+ auto type = arrayLengthExpr.getExpr()->arrayExpr->type.type->substitute(m_astBuilder, expr.getSubsts());
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ if (auto val = as<IntVal>(arrayType->arrayLength))
+ return val;
+ }
+ }
+ }
+
// it is possible that we are referring to a generic value param
if (auto declRefExpr = expr.as<DeclRefExpr>())
{
@@ -1871,14 +1884,42 @@ namespace Slang
arg = CheckTerm(arg);
}
- return CheckInvokeExprWithCheckedOperands(expr);
+ // If we are in a differentiable function, register differential witness tables involved in
+ // this call.
+ if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>())
+ {
+ for (auto& arg : expr->arguments)
+ {
+ maybeRegisterDifferentiableType(m_astBuilder, arg->type.type);
+ }
+ }
+
+ auto checkedExpr = CheckInvokeExprWithCheckedOperands(expr);
+
+ if (m_parentFunc && m_parentFunc->hasModifier<DifferentiableAttribute>())
+ {
+ if (auto checkedInvokeExpr = as<InvokeExpr>(checkedExpr))
+ {
+ // Register types for final resolved invoke arguments again.
+ for (auto& arg : expr->arguments)
+ {
+ maybeRegisterDifferentiableType(m_astBuilder, arg->type.type);
+ }
+ }
+ maybeRegisterDifferentiableType(m_astBuilder, checkedExpr->type.type);
+ }
+ return checkedExpr;
}
Expr* SemanticsExprVisitor::visitVarExpr(VarExpr *expr)
{
// If we've already resolved this expression, don't try again.
if (expr->declRef)
+ {
+ if (!expr->type)
+ expr->type = GetTypeForDeclRef(expr->declRef, expr->loc);
return expr;
+ }
expr->type = QualType(m_astBuilder->getErrorType());
auto lookupResult = lookUp(
@@ -1908,63 +1949,56 @@ namespace Slang
return expr;
}
- Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toDifferentialParamType(Type* primalType)
{
// Check for type modifiers like 'out' and 'inout'. We need to differentiate the
// nested type.
//
if (auto primalOutType = as<OutType>(primalType))
{
- return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType()));
+ return m_astBuilder->getOutType(_toDifferentialParamType(primalOutType->getValueType()));
}
else if (auto primalInOutType = as<InOutType>(primalType))
{
- return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType()));
+ return m_astBuilder->getInOutType(_toDifferentialParamType(primalInOutType->getValueType()));
}
+ return getDifferentialPairType(primalType);
+ }
+ Type* SemanticsVisitor::getDifferentialPairType(Type* primalType)
+ {
// Get a reference to the builtin 'IDifferentiable' interface
- auto differentiableInterface = builder->getDifferentiableInterface();
+ auto differentiableInterface = m_astBuilder->getDifferentiableInterface();
+ auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface));
// Check if the provided type inherits from IDifferentiable.
// If not, return the original type.
- if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)))
- return builder->getDifferentialPairType(primalType, conformanceWitness);
+ if (conformanceWitness)
+ return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness);
else
return primalType;
-
}
- Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType)
- {
- if (auto conformanceWitness =
- as<Witness>(tryGetInterfaceConformanceWitness(
- primalType,
- builder->getDifferentiableInterface())))
- return builder->getDifferentialPairType(primalType, conformanceWitness);
- else
- return primalType;
- }
-
- Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType)
+ Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType)
{
// Resolve JVP type here.
// Note that this type checking needs to be in sync with
// the auto-generation logic in slang-ir-jvp-diff.cpp
- FuncType* jvpType = builder->create<FuncType>();
+ FuncType* jvpType = m_astBuilder->create<FuncType>();
// The JVP return type is float if primal return type is float
// void otherwise.
//
- jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType());
+ jvpType->resultType = getDifferentialPairType(originalType->getResultType());
// No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType()));
+ SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
jvpType->errorType = originalType->errorType;
for (UInt i = 0; i < originalType->getParamCount(); i++)
{
- if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i)))
+ if(auto jvpParamType = _toDifferentialParamType(originalType->getParamType(i)))
jvpType->paramTypes.add(jvpParamType);
}
@@ -1978,6 +2012,15 @@ namespace Slang
// Check/Resolve inner function declaration.
expr->baseFunction = CheckTerm(expr->baseFunction);
+ // Register parameter types.
+ if (auto funcType = as<FuncType>(expr->baseFunction->type.type))
+ {
+ for (UInt i = 0; i < funcType->getParamCount(); i++)
+ {
+ maybeRegisterDifferentiableType(m_astBuilder, funcType->getParamType(i));
+ }
+ }
+
// For now we only support using higher order expr as callee in an invoke expr.
// The actual type of the higher order function will be derived during resolve invoke.
expr->type = m_astBuilder->getBottomType();
@@ -1985,6 +2028,29 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
+ {
+ expr->arrayExpr = CheckTerm(expr->arrayExpr);
+ if (auto arrType = as<ArrayExpressionType>(expr->arrayExpr->type))
+ {
+ expr->type = m_astBuilder->getIntType();
+ if (!arrType->arrayLength)
+ {
+ getSink()->diagnose(expr, Diagnostics::invalidArraySize);
+ }
+ }
+ else
+ {
+ if (!as<ErrorType>(expr->arrayExpr->type))
+ {
+ getSink()->diagnose(
+ expr, Diagnostics::typeMismatch, "array", expr->arrayExpr->type);
+ }
+ expr->type = m_astBuilder->getErrorType();
+ }
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitTypeCastExpr(TypeCastExpr * expr)
{
// Check the term we are applying first
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 821f785f5..33455e42d 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -794,15 +794,12 @@ namespace Slang
bool shouldSkipChecking(Decl* decl, DeclCheckState state);
// Auto-diff convenience functions for translating primal types to differential types.
- Type* _toDifferentialParamType(ASTBuilder* builder, Type* primalType);
+ Type* _toDifferentialParamType(Type* primalType);
+
+ Type* getDifferentialPairType(Type* primalType);
- // Translate a return type to the return type of a forward-mode differentiated
- // function.
- //
- Type* _toJVPReturnType(ASTBuilder* builder, Type* primalType);
-
// Convert a function's original type to it's JVP type.
- Type* processJVPFuncType(ASTBuilder* builder, FuncType* originalType);
+ Type* processJVPFuncType(FuncType* originalType);
// Check and register a type if it is differentiable.
void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
@@ -1038,6 +1035,11 @@ namespace Slang
DeclRef<GenericDecl> requirementGenDecl,
RefPtr<WitnessTable> witnessTable);
+ bool doesTypeSatisfyAssociatedTypeConstraintRequirement(
+ Type* satisfyingType,
+ DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
+ RefPtr<WitnessTable> witnessTable);
+
bool doesTypeSatisfyAssociatedTypeRequirement(
Type* satisfyingType,
DeclRef<AssocTypeDecl> requiredAssociatedTypeDeclRef,
@@ -1124,7 +1126,7 @@ namespace Slang
/// Otherwise, returns `false`.
bool trySynthesizeDifferentialAssociatedTypeRequirementWitness(
ConformanceCheckingContext* context,
- DeclRef<Decl> requirementDeclRef,
+ DeclRef<AssocTypeDecl> requirementDeclRef,
RefPtr<WitnessTable> witnessTable);
/// Registers a type as differentiable in the currrent semantic context, if the declaration represents
@@ -1989,6 +1991,8 @@ namespace Slang
Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr);
+ Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr);
+
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index ef067c06c..42fab94a6 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1559,7 +1559,7 @@ namespace Slang
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
- candidate.funcType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), origFuncType));
+ candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType));
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(baseFuncDeclRef);
@@ -1576,7 +1576,6 @@ namespace Slang
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
candidate.funcType = as<FuncType>(processJVPFuncType(
- this->getASTBuilder(),
as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc))));
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(item.declRef);
@@ -1606,7 +1605,7 @@ namespace Slang
auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>());
// Process func type to generate JVP func type.
- auto jvpFuncType = as<FuncType>(processJVPFuncType(this->getASTBuilder(), funcType));
+ auto jvpFuncType = as<FuncType>(processJVPFuncType(funcType));
// Extract parameter list from processed type.
List<Type*> paramTypes;
@@ -1631,7 +1630,6 @@ namespace Slang
// This could potentially be a declRef.substitute(jvpFuncType)
//
candidate.funcType = as<FuncType>(processJVPFuncType(
- this->getASTBuilder(),
getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
candidate.resultType = candidate.funcType->getResultType();
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 73818dbb1..3d02d4fc0 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -205,7 +205,7 @@ struct DifferentiableTypeConformanceContext
{
if (as<IRModuleInst>(inst) && differentiableInterfaceType)
{
- // Assume for now that IDifferentiable has exactly three fields.
+ // Assume for now that IDifferentiable has exactly four fields.
SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
return as<IRStructKey>(entry->getRequirementKey());
@@ -462,45 +462,6 @@ struct DifferentialPairTypeBuilder
}
}
- void _createGenericDiffPairType(IRBuilder* builder)
- {
- // Insert directly at top level (skip any generic scopes etc.)
- auto insertLoc = builder->getInsertLoc();
- builder->setInsertInto(builder->getModule()->getModuleInst());
-
- // Make a generic version of the pair struct.
- auto irGeneric = builder->emitGeneric();
- irGeneric->setFullType(builder->getTypeKind());
- builder->setInsertInto(irGeneric);
-
- generatedTypeList.add(irGeneric);
-
- auto irBlock = builder->emitBlock();
- builder->setInsertInto(irBlock);
-
- auto pTypeParam = builder->emitParam(builder->getTypeType());
- builder->addNameHintDecoration(pTypeParam, UnownedTerminatedStringSlice("pT"));
-
- auto dTypeParam = builder->emitParam(builder->getTypeType());
- builder->addNameHintDecoration(dTypeParam, UnownedTerminatedStringSlice("dT"));
-
- auto irStructType = builder->createStructType();
- builder->emitReturn(irStructType);
-
- auto primalKey = _getOrCreatePrimalStructKey(builder);
- builder->addNameHintDecoration(primalKey, UnownedTerminatedStringSlice("primal"));
- builder->createStructField(irStructType, primalKey, (IRType*) pTypeParam);
-
- auto diffKey = _getOrCreateDiffStructKey(builder);
- builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
- builder->createStructField(irStructType, diffKey, (IRType*) dTypeParam);
-
- // Reset cursor when done.
- builder->setInsertLoc(insertLoc);
-
- this->genericDiffPairType = irGeneric;
- }
-
IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder)
{
if (!this->globalDiffKey)
@@ -535,17 +496,6 @@ struct DifferentialPairTypeBuilder
return this->globalPrimalKey;
}
- IRInst* _getOrCreateGenericDiffPairType(IRBuilder* builder)
- {
- if (!this->genericDiffPairType)
- {
- _createGenericDiffPairType(builder);
- }
-
- SLANG_ASSERT(this->genericDiffPairType);
- return this->genericDiffPairType;
- }
-
IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
{
if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
@@ -1383,22 +1333,10 @@ struct JVPTranscriber
}
else
{
- // We special case a few non-differentiable types that sometimes appear in places
- // where we're forced to provide a differential zero value. For instance,
- // float3(float, float, int) is accepted by the compiler, but is tricky in the context
- // of differentiation since int is non-differentiable, and should be cast to float first.
- // In the absence of such casts, this piece of code generates appropriate zero values.
- //
- switch (primalType->getOp())
- {
- case kIROp_IntType:
- return builder->getIntValue(primalType, 0);
- default:
- getSink()->diagnose(primalType->sourceLoc,
- Diagnostics::internalCompilerError,
- "could not generate zero value for given type");
- return nullptr;
- }
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
}
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 1d1db14f9..61aa28bbe 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -60,6 +60,7 @@ INST(Nop, nop, 0, 0)
INST(OptionalType, Optional, 1, 0)
INST(DifferentialPairType, DiffPair, 1, 0)
+ INST(DifferentialBottomType, DiffBottomType, 0, 0)
/* BindExistentialsTypeBase */
@@ -277,7 +278,6 @@ INST(lookup_interface_method, lookup_interface_method, 2, 0)
INST(GetSequentialID, GetSequentialID, 1, 0)
INST(lookup_witness_table, lookup_witness_table, 2, 0)
INST(BindGlobalGenericParam, bind_global_generic_param, 2, 0)
-
INST(Construct, construct, 0, 0)
INST(AllocObj, allocObj, 0, 0)
@@ -297,6 +297,7 @@ INST(GetOptionalValue, getOptionalValue, 1, 0)
INST(OptionalHasValue, optionalHasValue, 1, 0)
INST(MakeOptionalValue, makeOptionalValue, 1, 0)
INST(MakeOptionalNone, makeOptionalNone, 1, 0)
+INST(DifferentialBottomValue, differentialBottomVal, 0, 0)
INST(Call, call, 1, 0)
INST(RTTIObject, rtti_object, 0, 0)
@@ -759,6 +760,7 @@ INST(Reinterpret, reinterpret, 1, 0)
INST(CastPtrToBool, CastPtrToBool, 1, 0)
INST(IsType, IsType, 3, 0)
INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0)
+INST(DifferentialEqualityTypeCast, DifferentialEqualityTypeCast, 1, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
INST(GetEquivalentStructuredBuffer, getEquivalentStructuredBuffer, 1, 0)
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 3a59eb6c9..382f7be5e 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3629,6 +3629,7 @@ namespace Slang
IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope)
{
+ IRInst* foundResult = nullptr;
for (auto child = scope->getFirstChild(); child; child = child->getNextInst())
{
if (child->getOp() == kIROp_DifferentiableTypeDictionary)
@@ -3640,13 +3641,20 @@ namespace Slang
if (irType == entryType)
{
- return entryConformanceWitness;
+ foundResult = entryConformanceWitness;
+ // If the found witness table is not a trivial one (i.e. DifferentialBottom:IDifferential),
+ // return immediately. Otherwise, continue the search to see if we can find a better one.
+ if (auto witness = as<IRWitnessTable>(foundResult))
+ {
+ if (witness->getConcreteType()->getOp() != kIROp_DifferentialBottomType)
+ return foundResult;
+ }
}
}
}
}
- return nullptr;
+ return foundResult;
}
IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType)
@@ -6282,6 +6290,8 @@ namespace Slang
case kIROp_MakeOptionalNone:
case kIROp_OptionalHasValue:
case kIROp_GetOptionalValue:
+ case kIROp_MakeTuple:
+ case kIROp_GetTupleElement:
case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads
case kIROp_ImageSubscript:
case kIROp_FieldExtract:
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index ae0590105..8f00253f5 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -1366,12 +1366,29 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// produce transitive witnesses in shapes that will cuase us
// problems here.
//
- IRInst* requirementKey = lowerSimpleVal(context, val->midToSup);
+ IRInst* midToSup = lowerSimpleVal(context, val->midToSup);
+
+ if (!baseWitnessTable)
+ {
+ // If we don't have a valid baseWitnessTable,
+ // we are in a situation that `subToMid` is a `DifferentialBottomSubtypeWitness`
+ // that applies for all non-differentiable types.
+ // In this case `midToSup` will give us the `DifferentialBottom:IDifferentiable`
+ // witness table and we can just use that as the final result of
+ // this transitive witness.
+ SLANG_RELEASE_ASSERT(midToSup && as<IRWitnessTableType>(midToSup->getDataType()));
+ return LoweredValInfo::simple(midToSup);
+ }
return LoweredValInfo::simple(getBuilder()->emitLookupInterfaceMethodInst(
getBuilder()->getWitnessTableType(lowerType(context, val->sup)),
baseWitnessTable,
- requirementKey));
+ midToSup));
+ }
+
+ LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*)
+ {
+ return LoweredValInfo();
}
LoweredValInfo visitTaggedUnionSubtypeWitness(
@@ -3053,6 +3070,15 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->arrayExpr);
+ auto type = lowerType(context, expr->arrayExpr->type);
+ auto arrayType = as<IRArrayType>(type);
+ SLANG_ASSERT(arrayType);
+ return LoweredValInfo::simple(arrayType->getElementCount());
+ }
+
LoweredValInfo visitOverloadedExpr(OverloadedExpr* /*expr*/)
{
SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST");
@@ -5857,7 +5883,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// add an entry to the context.
//
if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType))
+ {
getBuilder()->addDifferentiableTypeEntry(irType, irWitness);
+ }
}
else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member))
{
@@ -6777,7 +6805,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IRInterfaceType* irInterface = subBuilder->createInterfaceType(operandCount, nullptr);
// Add `irInterface` to decl mapping now to prevent cyclic lowering.
- setValue(subContext, decl, LoweredValInfo::simple(irInterface));
+ setValue(context, decl, LoweredValInfo::simple(irInterface));
// Setup subContext for proper lowering `ThisType`, associated types and
// the interface decl's self reference.
@@ -7084,6 +7112,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
void lowerDerivativeMemberModifier(IRInst* inst, DerivativeMemberAttribute* derivativeMember)
{
+ ensureDecl(context, derivativeMember->memberDeclRef->declRef.getDecl()->parentDecl);
auto key = lowerRValueExpr(context, derivativeMember->memberDeclRef).val;
SLANG_RELEASE_ASSERT(as<IRStructKey>(key));
auto builder = getBuilder();
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 93f2fcdcb..980a1d0bc 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -132,8 +132,8 @@ namespace Slang
Scope* newScope = astBuilder->create<Scope>();
newScope->containerDecl = containerDecl;
newScope->parent = currentScope;
-
currentScope = newScope;
+ containerDecl->ownedScope = newScope;
}
void pushScopeAndSetParent(ContainerDecl* containerDecl)
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index c779b4510..8cd443438 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -319,7 +319,31 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
}
}
}
-
+ else if (auto transitiveTypeWitness = as<TransitiveSubtypeWitness>(subtypeWitness))
+ {
+ // Hard code witness entry that `T.Differential = DifferentialBottom` for `T` that
+ // coerce to `DifferentialBottom`.
+ if (astBuilder->getDifferentialBottomType()->equals(transitiveTypeWitness->subToMid->sup))
+ {
+ if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementAttribute>())
+ {
+ if (builtinAttr->kind == BuiltinRequirementKind::DifferentialType)
+ {
+ return RequirementWitness(astBuilder->getDifferentialBottomType());
+ }
+ }
+ }
+ }
+ else if (auto extractFromConjunctionTypeWitness = as<ExtractFromConjunctionSubtypeWitness>(subtypeWitness))
+ {
+ if (auto conjunctionTypeWitness = as<ConjunctionSubtypeWitness>(extractFromConjunctionTypeWitness->conjunctionWitness))
+ {
+ if (extractFromConjunctionTypeWitness->indexInConjunction == 0)
+ return tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(conjunctionTypeWitness->leftWitness), requirementKey);
+ else
+ return tryLookUpRequirementWitness(astBuilder, as<SubtypeWitness>(conjunctionTypeWitness->rightWitness), requirementKey);
+ }
+ }
// TODO: should handle the transitive case here too
return RequirementWitness();
@@ -1140,7 +1164,46 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
return nullptr;
}
- //
+ Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef)
+ {
+ auto substDeclRef = declRef.as<AssocTypeDecl>();
+ if (!substDeclRef)
+ return nullptr;
+
+ auto substAssocTypeDecl = substDeclRef.getDecl();
+
+ for (auto s = substDeclRef.substitutions.substitutions; s; s = s->outer)
+ {
+ auto thisSubst = as<ThisTypeSubstitution>(s);
+ if (!thisSubst)
+ continue;
+
+ if (auto interfaceDecl = as<InterfaceDecl>(substAssocTypeDecl->parentDecl))
+ {
+ if (thisSubst->interfaceDecl == interfaceDecl)
+ {
+ // We need to look up the declaration that satisfies
+ // the requirement named by the associated type.
+ Decl* requirementKey = substAssocTypeDecl;
+ RequirementWitness requirementWitness = tryLookUpRequirementWitness(builder, thisSubst->witness, requirementKey);
+ switch (requirementWitness.getFlavor())
+ {
+ default:
+ // No usable value was found, so there is nothing we can do.
+ break;
+
+ case RequirementWitness::Flavor::val:
+ {
+ auto satisfyingVal = requirementWitness.getVal();
+ return satisfyingVal;
+ }
+ break;
+ }
+ }
+ }
+ }
+ return nullptr;
+ }
String DeclRefBase::toString() const
{
diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis
index ee868be62..b38bd358f 100644
--- a/source/slang/slang.natvis
+++ b/source/slang/slang.natvis
@@ -233,7 +233,7 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::LetExpr">(Slang::LetExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialValueExpr">(Slang::ExtractExistentialValueExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OpenRefExpr">(Slang::OpenRefExpr*)&amp;astNodeType</ExpandedItem>
- <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::JVPDifferentiateExpr">(Slang::JVPDifferentiateExpr*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ForwardDifferentiateExpr">(Slang::ForwardDifferentiateExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionTypeExpr">(Slang::TaggedUnionTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeExpr">(Slang::ThisTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&amp;astNodeType</ExpandedItem>
@@ -384,6 +384,7 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ErrorType">(Slang::ErrorType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BottomType">(Slang::BottomType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclRefType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">(Slang::ArithmeticExpressionType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">(Slang::BasicExpressionType*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">(Slang::VectorExpressionType*)&amp;astNodeType</ExpandedItem>
@@ -456,7 +457,7 @@
</Expand>
</Type>
- <Type Name="Slang::Substitutions">
+ <Type Name="Slang::Substitutions" Inheritable="false">
<DisplayString>{astNodeType}</DisplayString>
<Expand>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericSubstitution">(Slang::GenericSubstitution*)&amp;astNodeType</ExpandedItem>
@@ -469,7 +470,7 @@
<LinkedListItems>
<HeadPointer>substitutions</HeadPointer>
<NextPointer>outer</NextPointer>
- <ValueNode>this</ValueNode>
+ <ValueNode>(Slang::Substitutions*)this</ValueNode>
</LinkedListItems>
</Expand>
</Type>
@@ -487,6 +488,79 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericParamIntVal">(Slang::GenericParamIntVal*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclaredSubtypeWitness">(Slang::DeclaredSubtypeWitness*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TransitiveSubtypeWitness">(Slang::TransitiveSubtypeWitness*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OverloadGroupType">(Slang::OverloadGroupType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::InitializerListType">(Slang::InitializerListType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ErrorType">(Slang::ErrorType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BottomType">(Slang::BottomType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DeclRefType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DifferentialPairType">(Slang::DeclRefType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArithmeticExpressionType">(Slang::ArithmeticExpressionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BasicExpressionType">(Slang::BasicExpressionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VectorExpressionType">(Slang::VectorExpressionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::MatrixExpressionType">(Slang::MatrixExpressionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BuiltinType">(Slang::BuiltinType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FeedbackType">(Slang::FeedbackType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ResourceType">(Slang::ResourceType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureTypeBase">(Slang::TextureTypeBase*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureType">(Slang::TextureType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureSamplerType">(Slang::TextureSamplerType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLImageType">(Slang::GLSLImageType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::SamplerStateType">(Slang::SamplerStateType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BuiltinGenericType">(Slang::BuiltinGenericType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PointerLikeType">(Slang::PointerLikeType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ParameterGroupType">(Slang::ParameterGroupType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::UniformParameterGroupType">(Slang::UniformParameterGroupType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ConstantBufferType">(Slang::ConstantBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TextureBufferType">(Slang::TextureBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLShaderStorageBufferType">(Slang::GLSLShaderStorageBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ParameterBlockType">(Slang::ParameterBlockType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::VaryingParameterGroupType">(Slang::VaryingParameterGroupType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLInputParameterGroupType">(Slang::GLSLInputParameterGroupType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLOutputParameterGroupType">(Slang::GLSLOutputParameterGroupType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferTypeBase">(Slang::HLSLStructuredBufferTypeBase*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLStructuredBufferType">(Slang::HLSLStructuredBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRWStructuredBufferType">(Slang::HLSLRWStructuredBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedStructuredBufferType">(Slang::HLSLRasterizerOrderedStructuredBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLAppendStructuredBufferType">(Slang::HLSLAppendStructuredBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLConsumeStructuredBufferType">(Slang::HLSLConsumeStructuredBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLStreamOutputType">(Slang::HLSLStreamOutputType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLPointStreamType">(Slang::HLSLPointStreamType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLLineStreamType">(Slang::HLSLLineStreamType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLTriangleStreamType">(Slang::HLSLTriangleStreamType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::UntypedBufferResourceType">(Slang::UntypedBufferResourceType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLByteAddressBufferType">(Slang::HLSLByteAddressBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRWByteAddressBufferType">(Slang::HLSLRWByteAddressBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLRasterizerOrderedByteAddressBufferType">(Slang::HLSLRasterizerOrderedByteAddressBufferType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::RaytracingAccelerationStructureType">(Slang::RaytracingAccelerationStructureType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLPatchType">(Slang::HLSLPatchType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLInputPatchType">(Slang::HLSLInputPatchType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::HLSLOutputPatchType">(Slang::HLSLOutputPatchType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GLSLInputAttachmentType">(Slang::GLSLInputAttachmentType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::StringTypeBase">(Slang::StringTypeBase*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::StringType">(Slang::StringType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NativeStringType">(Slang::NativeStringType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::DynamicType">(Slang::DynamicType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::EnumTypeType">(Slang::EnumTypeType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PtrTypeBase">(Slang::PtrTypeBase*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::PtrType">(Slang::PtrType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ParamDirectionType">(Slang::ParamDirectionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OutTypeBase">(Slang::OutTypeBase*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OutType">(Slang::OutType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::InOutType">(Slang::InOutType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::RefType">(Slang::RefType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NullPtrType">(Slang::NullPtrType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ArrayExpressionType">(Slang::ArrayExpressionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TypeType">(Slang::TypeType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamedExpressionType">(Slang::NamedExpressionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::FuncType">(Slang::FuncType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::GenericDeclRefType">(Slang::GenericDeclRefType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::NamespaceType">(Slang::NamespaceType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialType">(Slang::ExtractExistentialType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionType">(Slang::TaggedUnionType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExistentialSpecializedType">(Slang::ExistentialSpecializedType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisType">(Slang::ThisType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndType">(Slang::AndType*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ModifiedType">(Slang::ModifiedType*)&amp;astNodeType</ExpandedItem>
</Expand>
</Type>
</AutoVisualizer> \ No newline at end of file