diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-24 13:22:07 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-24 13:22:07 -0800 |
| commit | 91694dacdb8d3ab7dd9783d7c0c43629bf11f578 (patch) | |
| tree | 19eb8db8845b22c379ebc3f114d610c1e401bd9d /source/slang | |
| parent | bd6306cdaa4a49344658bd026721b6532e103d09 (diff) | |
Fix differential type registration through non-differentiable type. (#2677)
* Fix differential type registration through non-differentiable type.
* More fix.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ast-base.h | 35 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 34 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 98 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 9 |
8 files changed, 77 insertions, 116 deletions
diff --git a/source/slang/slang-ast-base.h b/source/slang/slang-ast-base.h index 627a56152..1b8b221ef 100644 --- a/source/slang/slang-ast-base.h +++ b/source/slang/slang-ast-base.h @@ -150,6 +150,41 @@ class Val : public NodeBase HashCode _getHashCodeOverride(); }; +struct ValSet +{ + struct ValItem + { + Val* val = nullptr; + ValItem() = default; + ValItem(Val* v) : val(v) {} + + HashCode getHashCode() + { + return val ? val->getHashCode() : 0; + } + bool operator==(ValItem other) + { + if (val == other.val) + return true; + if (val) + return val->equalsVal(other.val); + else if (other.val) + return other.val->equalsVal(val); + return false; + } + }; + HashSet<ValItem> set; + bool add(Val* val) + { + return set.Add(ValItem(val)); + } + bool contains(Val* val) + { + return set.Contains(ValItem(val)); + } +}; + + SLANG_FORCE_INLINE StringBuilder& operator<<(StringBuilder& io, Val* val) { SLANG_ASSERT(val); val->toText(io); return io; } /// Given a `value` that refers to a `param` of some generic, attempt to apply diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index d44a15813..c5e9e5429 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -424,40 +424,6 @@ protected: }; -struct ValSet -{ - struct ValItem - { - Val* val = nullptr; - ValItem() = default; - ValItem(Val* v) : val(v) {} - - HashCode getHashCode() - { - return val ? val->getHashCode() : 0; - } - bool operator==(ValItem other) - { - if (val == other.val) - return true; - if (val) - return val->equalsVal(other.val); - else if (other.val) - return other.val->equalsVal(val); - return false; - } - }; - HashSet<ValItem> set; - bool add(Val* val) - { - return set.Add(ValItem(val)); - } - bool contains(Val* val) - { - return set.Contains(ValItem(val)); - } -}; - } // namespace Slang #endif diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 6ac464784..7dd0819d8 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1045,6 +1045,8 @@ class DifferentiableAttribute : public Attribute /// Mapping from types to subtype witnesses for conformance to IDifferentiable. OrderedDictionary<DeclRefBase, SubtypeWitness*> m_mapTypeToIDifferentiableWitness; + + SLANG_UNREFLECTED ValSet m_typeRegistrationWorkingSet; }; class DllImportAttribute : public Attribute diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 381efa2c7..142842e12 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5037,7 +5037,6 @@ namespace Slang auto thisType = calcThisType(parentDeclRef); maybeRegisterDifferentiableType(m_astBuilder, thisType); } - completeDifferentiableTypeDictionary(); m_parentDifferentiableAttr = oldAttr; } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3567e2593..2803b5959 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -923,7 +923,7 @@ namespace Slang return result; } - void SemanticsVisitor::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness) + void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness) { SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr); if (witness) @@ -944,49 +944,23 @@ namespace Slang return; } - // Check for special cases such as PtrTypeBase<T> or Array<T> - // This could potentially be handled later by simply defining extensions - // for Ptr<T:IDifferentiable> etc.. - // - if (auto ptrType = as<PtrTypeBase>(type)) - { - maybeRegisterDifferentiableType(builder, ptrType->getValueType()); - return; - } - - if (auto arrayType = as<ArrayExpressionType>(type)) - { - maybeRegisterDifferentiableType(builder, arrayType->getElementType()); - // Fall through to register the array type itself. - } - - if (auto declRefType = as<DeclRefType>(type)) - { - if (auto subtypeWitness = as<SubtypeWitness>( - tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) - { - registerDifferentiableType((DeclRefType*)type, subtypeWitness); - } - return; - } + maybeRegisterDifferentiableTypeImplRecursive(builder, type); } - void SemanticsVisitor::maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet) + void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type) { - if (workingSet.contains(type)) + // Recursively visit the tree of type and register all differentiable types along the way. + + if (as<TypeType>(type)) return; - - if (!builder->isDifferentiableInterfaceAvailable()) - { + if (!type) return; - } - if (!m_parentDifferentiableAttr) - { + // Have we already registered this type? If so we can exit now. + if (m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.contains(type)) return; - } - workingSet.add(type); + m_parentDifferentiableAttr->m_typeRegistrationWorkingSet.add(type); // Check for special cases such as PtrTypeBase<T> or Array<T> // This could potentially be handled later by simply defining extensions @@ -994,13 +968,13 @@ namespace Slang // if (auto ptrType = as<PtrTypeBase>(type)) { - maybeRegisterDifferentiableTypeRecursive(builder, ptrType->getValueType(), workingSet); + maybeRegisterDifferentiableTypeImplRecursive(builder, ptrType->getValueType()); return; } if (auto arrayType = as<ArrayExpressionType>(type)) { - maybeRegisterDifferentiableTypeRecursive(builder, arrayType->getElementType(), workingSet); + maybeRegisterDifferentiableTypeImplRecursive(builder, arrayType->getElementType()); // Fall through to register the array type itself. } @@ -1009,20 +983,20 @@ namespace Slang if (auto subtypeWitness = as<SubtypeWitness>( tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) { - registerDifferentiableType((DeclRefType*)type, subtypeWitness); - if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) - { - foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) - { - auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, subType, workingSet); - }); - foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member) - { - auto fieldType = getType(m_astBuilder, member); - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, fieldType, workingSet); - }); - } + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + } + if (auto aggTypeDeclRef = declRefType->declRef.as<AggTypeDecl>()) + { + foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) + { + auto subType = m_astBuilder->getOrCreateDeclRefType(member.getDecl(), nullptr); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, subType); + }); + foreachDirectOrExtensionMemberOfType<VarDeclBase>(this, aggTypeDeclRef, [&](DeclRef<VarDeclBase> member) + { + auto fieldType = getType(m_astBuilder, member); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, fieldType); + }); } for (auto subst = declRefType->declRef.substitutions.substitutions; subst; subst = subst->outer) { @@ -1032,35 +1006,19 @@ namespace Slang { if (auto typeArg = as<Type>(arg)) { - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, typeArg, workingSet); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, typeArg); } } } else if (auto thisSubst = as<ThisTypeSubstitution>(subst)) { - maybeRegisterDifferentiableTypeRecursive(m_astBuilder, thisSubst->witness->sub, workingSet); + maybeRegisterDifferentiableTypeImplRecursive(m_astBuilder, thisSubst->witness->sub); } } return; } } - void SemanticsVisitor::completeDifferentiableTypeDictionary() - { - ValSet workingSet; - for (auto type : m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness) - { - if (auto aggTypeDeclRef = type.Key.as<AggTypeDecl>()) - { - maybeRegisterDifferentiableTypeRecursive( - m_astBuilder, - m_astBuilder->getOrCreateDeclRefType( - aggTypeDeclRef.getDecl(), aggTypeDeclRef.substitutions), - workingSet); - } - } - } - Expr* SemanticsVisitor::CheckTerm(Expr* term) { diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 165c84192..11aacd255 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -759,9 +759,8 @@ namespace Slang /// Registers a type as conforming to IDifferentiable, along with a witness /// describing the relationship. /// - void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness); - void maybeRegisterDifferentiableTypeRecursive(ASTBuilder* builder, Type* type, ValSet& workingSet); - void completeDifferentiableTypeDictionary(); + void addDifferentiableTypeToDiffTypeRegistry(DeclRefType* type, SubtypeWitness* witness); + void maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder* builder, Type* type); // Construct the differential for 'type', if it exists. Type* getDifferentialType(ASTBuilder* builder, Type* type, SourceLoc loc); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 564c33268..b2f420730 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -283,8 +283,11 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* else { auto operandDataType = origConstruct->getOperand(ii)->getDataType(); - operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType); - diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + if (auto diffOperandType = differentiateType(builder, operandDataType)) + { + operandDataType = (IRType*)findOrTranscribePrimalInst(builder, operandDataType); + diffOperands.add(getDifferentialZeroOfType(builder, operandDataType)); + } } } @@ -293,7 +296,7 @@ InstPair ForwardDiffTranscriber::transcribeConstruct(IRBuilder* builder, IRInst* builder->emitIntrinsicInst( diffConstructType, origConstruct->getOp(), - operandCount, + diffOperands.getCount(), diffOperands.getBuffer())); } else diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index a8e06bf91..73d9b6ba6 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -21,15 +21,14 @@ void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diff { if (hasDifferentialInst(origInst)) { - if (lookupDiffInst(origInst) != diffInst) + auto existingDiffInst = lookupDiffInst(origInst); + if (existingDiffInst != diffInst) { SLANG_UNEXPECTED("Inconsistent differential mappings"); } } - else - { - instMapD.Add(origInst, diffInst); - } + + instMapD[origInst] = diffInst; } void AutoDiffTranscriberBase::mapPrimalInst(IRInst* origInst, IRInst* primalInst) |
