summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-24 13:22:07 -0800
committerGitHub <noreply@github.com>2023-02-24 13:22:07 -0800
commit91694dacdb8d3ab7dd9783d7c0c43629bf11f578 (patch)
tree19eb8db8845b22c379ebc3f114d610c1e401bd9d /source/slang
parentbd6306cdaa4a49344658bd026721b6532e103d09 (diff)
Fix differential type registration through non-differentiable type. (#2677)
* Fix differential type registration through non-differentiable type. * More fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/slang-ast-base.h35
-rw-r--r--source/slang/slang-ast-builder.h34
-rw-r--r--source/slang/slang-ast-modifier.h2
-rw-r--r--source/slang/slang-check-decl.cpp1
-rw-r--r--source/slang/slang-check-expr.cpp98
-rw-r--r--source/slang/slang-check-impl.h5
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp9
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp9
8 files changed, 77 insertions, 116 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h
index 627a56152..1b8b221ef 100644
--- a/source/slang/slang-ast-base.h
+++ b/source/slang/slang-ast-base.h
@@ -150,6 +150,41 @@ class Val : public NodeBase
HashCode _getHashCodeOverride();
};
+struct ValSet
+{
+ struct ValItem
+ {
+ Val* val = nullptr;
+ ValItem() = default;
+ ValItem(Val* v) : val(v) {}
+
+ HashCode getHashCode()
+ {
+ return val ? val->getHashCode() : 0;
+ }
+ bool operator==(ValItem other)
+ {
+ if (val == other.val)
+ return true;
+ if (val)
+ return val->equalsVal(other.val);
+ else if (other.val)
+ return other.val->equalsVal(val);
+ return false;
+ }
+ };
+ HashSet<ValItem> set;
+ bool add(Val* val)
+ {
+ return set.Add(ValItem(val));
+ }
+ bool contains(Val* val)
+ {
+ return set.Contains(ValItem(val));
+ }
+};
+
+
SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Val* val) { SLANG_ASSERT(val); val->toText(io); return io; }
/// Given a `value` that refers to a `param` of some generic, attempt to apply
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index d44a15813..c5e9e5429 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -424,40 +424,6 @@ protected:
};
-struct ValSet
-{
- struct ValItem
- {
- Val* val = nullptr;
- ValItem() = default;
- ValItem(Val* v) : val(v) {}
-
- HashCode getHashCode()
- {
- return val ? val->getHashCode() : 0;
- }
- bool operator==(ValItem other)
- {
- if (val == other.val)
- return true;
- if (val)
- return val->equalsVal(other.val);
- else if (other.val)
- return other.val->equalsVal(val);
- return false;
- }
- };
- HashSet<ValItem> set;
- bool add(Val* val)
- {
- return set.Add(ValItem(val));
- }
- bool contains(Val* val)
- {
- return set.Contains(ValItem(val));
- }
-};
-
} // namespace Slang
#endif
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 6ac464784..7dd0819d8 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1045,6 +1045,8 @@ class DifferentiableAttribute : public Attribute
/// Mapping from types to subtype witnesses for conformance to IDifferentiable.
OrderedDictionary<DeclRefBase, SubtypeWitness*> m_mapTypeToIDifferentiableWitness;
+
+ SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet;
};
class DllImportAttribute : public Attribute
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 381efa2c7..142842e12 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -5037,7 +5037,6 @@ namespace Slang
auto thisType = calcThisType(parentDeclRef);
maybeRegisterDifferentiableType(m_astBuilder, thisType);
}
- completeDifferentiableTypeDictionary();
m_parentDifferentiableAttr = oldAttr;
}
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3567e2593..2803b5959 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -923,7 +923,7 @@ namespace Slang
return result;
}
- void SemanticsVisitor::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness)
+ void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness)
{
SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
if (witness)
@@ -944,49 +944,23 @@ namespace Slang
return;
}
- // Check for special cases such as PtrTypeBase<T> or Array<T>
- // This could potentially be handled later by simply defining extensions
- // for Ptr<T:IDifferentiable> etc..
- //
- if (auto ptrType = as<PtrTypeBase>(type))
- {
- maybeRegisterDifferentiableType(builder, ptrType->getValueType());
- return;
- }
-
- if (auto arrayType = as<ArrayExpressionType>(type))
- {
- maybeRegisterDifferentiableType(builder, arrayType->getElementType());
- // Fall through to register the array type itself.
- }
-
- if (auto declRefType = as<DeclRefType>(type))
- {
- if (auto subtypeWitness = as<SubtypeWitness>(
- tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
- {
- registerDifferentiableType((DeclRefType*)type, subtypeWitness);
- }
- return;
- }
+ maybeRegisterDifferentiableTypeImplRecursive(builder, type);
}
- void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet)
+ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type)
{
- if (workingSet.contains(type))
+ // Recursively visit the tree of type and register all differentiable types along the way.
+
+ if (as<TypeType>(type))
return;
-
- if (!builder->isDifferentiableInterfaceAvailable())
- {
+ if (!type)
return;
- }
- if (!m_parentDifferentiableAttr)
- {
+ // Have we already registered this type? If so we can exit now.
+ if (m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.contains(type))
return;
- }
- workingSet.add(type);
+ m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.add(type);
// Check for special cases such as PtrTypeBase<T> or Array<T>
// This could potentially be handled later by simply defining extensions
@@ -994,13 +968,13 @@ namespace Slang
//
if (auto ptrType = as<PtrTypeBase>(type))
{
- maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(builder, ptrType->getValueType());
return;
}
if (auto arrayType = as<ArrayExpressionType>(type))
{
- maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(builder, arrayType->getElementType());
// Fall through to register the array type itself.
}
@@ -1009,20 +983,20 @@ namespace Slang
if (auto subtypeWitness = as<SubtypeWitness>(
tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
{
- registerDifferentiableType((DeclRefType*)type, subtypeWitness);
- if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
- {
- foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
- {
- auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr);
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet);
- });
- foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member)
- {
- auto fieldType = getType(m_astBuilder, member);
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet);
- });
- }
+ addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ }
+ if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>())
+ {
+ foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member)
+ {
+ auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, subType);
+ });
+ foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member)
+ {
+ auto fieldType = getType(m_astBuilder, member);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType);
+ });
}
for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer)
{
@@ -1032,35 +1006,19 @@ namespace Slang
{
if (auto typeArg = as<Type>(arg))
{
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, typeArg, workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg);
}
}
}
else if (auto thisSubst = as<ThisTypeSubstitution>(subst))
{
- maybeRegisterDifferentiableTypeRecursive(m_astBuilder, thisSubst->witness->sub, workingSet);
+ maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub);
}
}
return;
}
}
- void SemanticsVisitor::completeDifferentiableTypeDictionary()
- {
- ValSet workingSet;
- for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness)
- {
- if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>())
- {
- maybeRegisterDifferentiableTypeRecursive(
- m_astBuilder,
- m_astBuilder->getOrCreateDeclRefType(
- aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions),
- workingSet);
- }
- }
- }
-
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 165c84192..11aacd255 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -759,9 +759,8 @@ namespace Slang
/// Registers a type as conforming to IDifferentiable, along with a witness
/// describing the relationship.
///
- void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness);
- void maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet);
- void completeDifferentiableTypeDictionary();
+ void addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness);
+ void maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type);
// Construct the differential for 'type', if it exists.
Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc);
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 564c33268..b2f420730 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -283,8 +283,11 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
else
{
auto operandDataType = origConstruct->getOperand(ii)->getDataType();
- operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType);
- diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
+ if (auto diffOperandType = differentiateType(builder, operandDataType))
+ {
+ operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType);
+ diffOperands.add(getDifferentialZeroOfType(builder, operandDataType));
+ }
}
}
@@ -293,7 +296,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst*
builder->emitIntrinsicInst(
diffConstructType,
origConstruct->getOp(),
- operandCount,
+ diffOperands.getCount(),
diffOperands.getBuffer()));
}
else
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index a8e06bf91..73d9b6ba6 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -21,15 +21,14 @@ void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diff
{
if (hasDifferentialInst(origInst))
{
- if (lookupDiffInst(origInst) != diffInst)
+ auto existingDiffInst = lookupDiffInst(origInst);
+ if (existingDiffInst != diffInst)
{
SLANG_UNEXPECTED("Inconsistent differential mappings");
}
}
- else
- {
- instMapD.Add(origInst, diffInst);
- }
+
+ instMapD[origInst] = diffInst;
}
void AutoDiffTranscriberBase::mapPrimalInst(IRInst* origInst, IRInst* primalInst)