summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/slang-ast-decl.h30
-rw-r--r--source/slang/slang-ast-dump.cpp29
-rw-r--r--source/slang/slang-ast-modifier.h3
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-check-decl.cpp186
-rw-r--r--source/slang/slang-check-expr.cpp30
-rw-r--r--source/slang/slang-check-impl.h114
-rw-r--r--source/slang/slang-check-type.cpp13
-rw-r--r--source/slang/slang-emit-cpp.cpp36
-rw-r--r--source/slang/slang-ir-dce.cpp7
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp451
-rw-r--r--source/slang/slang-ir-inst-defs.h4
-rw-r--r--source/slang/slang-ir-insts.h28
-rw-r--r--source/slang/slang-ir-link.cpp29
-rw-r--r--source/slang/slang-ir-util.cpp33
-rw-r--r--source/slang/slang-ir-util.h8
-rw-r--r--source/slang/slang-ir.cpp122
-rw-r--r--source/slang/slang-ir.h1
-rw-r--r--source/slang/slang-lower-to-ir.cpp92
-rw-r--r--source/slang/slang-serialize-ast-type-info.h3
21 files changed, 325 insertions, 898 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 5df9d01fe..ce52dbb56 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2737,9 +2737,6 @@ __attributeTarget(InterfaceDecl)
attribute_syntax [Specialize] : SpecializeAttribute;
__attributeTarget(DeclBase)
-attribute_syntax [Differentiable] : DifferentiableAttribute;
-
-__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
enum _BuiltinRequirementKind
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h
index 90175dd9d..cbd3f0f0c 100644
--- a/source/slang/slang-ast-decl.h
+++ b/source/slang/slang-ast-decl.h
@@ -511,36 +511,6 @@ class AttributeDecl : public ContainerDecl
SyntaxClass<NodeBase> syntaxClass;
};
-// A declaration to hold differentiable type conformances generated during
-// the semantic checking phase.
-//
-class DifferentiableTypeDictionary : public ContainerDecl
-{
- SLANG_AST_CLASS(DifferentiableTypeDictionary);
-};
-
-// A declaration to hold differentiable type conformances generated during
-// the semantic checking phase.
-//
-class DifferentiableTypeDictionaryItem : public Decl
-{
- SLANG_AST_CLASS(DifferentiableTypeDictionaryItem);
-
- DeclRefType* baseType;
- SubtypeWitness* confWitness;
-};
-
-// A declaration that references another dictionary (generally from another module)
-// Used to tell the IR lowering pass to process the referenced dictionary.
-//
-class DifferentiableTypeDictionaryImportItem : public Decl
-{
- SLANG_AST_CLASS(DifferentiableTypeDictionaryImportItem);
-
- DeclRef<DifferentiableTypeDictionary> dictionaryRef;
-};
-
-
bool isInterfaceRequirement(Decl* decl);
} // namespace Slang
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index 455a9db74..fc3c015e0 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -408,6 +408,35 @@ struct ASTDumpContext
m_writer->emit("}");
}
+ template <typename KEY, typename VALUE>
+ void dump(const OrderedDictionary<KEY, VALUE>& dict)
+ {
+ m_writer->emit(" { \n");
+ m_writer->indent();
+
+ for (auto iter : dict)
+ {
+ const auto& key = iter.Key;
+ const auto& value = iter.Value;
+
+ dump(key);
+ m_writer->emit(" : ");
+ dump(value);
+
+ m_writer->emit("\n");
+ }
+
+ m_writer->dedent();
+ m_writer->emit("}");
+ }
+
+ void dump(DeclRefBase declRef)
+ {
+ StringBuilder sb;
+ sb << declRef;
+ m_writer->emit(sb.ToString());
+ }
+
void dump(const DeclCheckStateExt& extState)
{
auto state = extState.getState();
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 0c1eb8d49..67ff297dc 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -977,6 +977,9 @@ class SpecializeAttribute : public Attribute
class DifferentiableAttribute : public Attribute
{
SLANG_AST_CLASS(DifferentiableAttribute)
+
+ /// Mapping from types to subtype witnesses for conformance to IDifferentiable.
+ OrderedDictionary<DeclRefBase, SubtypeWitness*> m_mapTypeToIDifferentiableWitness;
};
class DllImportAttribute : public Attribute
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index d4a781846..015e6969c 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1510,7 +1510,6 @@ namespace Slang
DAddFunc, ///< The `IDifferentiable.dadd` function requirement
DMulFunc, ///< The `IDifferentiable.dmul` function requirement
};
-
} // namespace Slang
#endif
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index f60fbcc2c..7140d541a 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -903,14 +903,6 @@ namespace Slang
// If `decl` is a container, then we want to ensure its children.
if(auto containerDecl = as<ContainerDecl>(decl))
{
- bool trackDiffTypes = (as<GenericDecl>(decl) != nullptr);
- if (trackDiffTypes)
- {
- // Add a context to track differentiable types.
- DifferentiableTypeSemanticContext subDiffTypeContext;
- visitor->getShared()->pushDiffTypeContext(&subDiffTypeContext);
- }
-
// NOTE! We purposefully do not iterate with the for(auto childDecl : containerDecl->members) here,
// because the visitor may add to `members` whilst iteration takes place, invalidating the iterator
// and likely a crash.
@@ -932,22 +924,6 @@ namespace Slang
_ensureAllDeclsRec(visitor, childDecl, state);
}
-
- if (trackDiffTypes)
- {
- auto subDiffTypeContext = visitor->getShared()->popDiffTypeContext();
-
- // If there were any differentiable types used in differentiable
- // methods, generate a dictionary with the required info.
- //
- if (subDiffTypeContext->isDictionaryRequired())
- {
- auto diffTypeDict = subDiffTypeContext->makeDifferentiableTypeDictionaryNode(visitor->getASTBuilder());
- diffTypeDict->parentDecl = containerDecl;
- containerDecl->members.add(diffTypeDict);
- containerDecl->invalidateMemberDictionary();
- }
- }
}
// Note: the "inner" declaration of a `GenericDecl` is currently
@@ -1541,49 +1517,6 @@ namespace Slang
return false;
}
- void SemanticsVisitor::tryAddDifferentiableConformanceToContext(Decl* decl, DifferentiableTypeSemanticContext*)
- {
- // If the autodiff core library (diff.meta.slang) has not been loaded yet, ignore any
- // request to check differentiable types.
- //
- if (!m_astBuilder->isDifferentiableInterfaceAvailable())
- return;
-
- auto diffInterface = m_astBuilder->getDifferentiableInterface();
-
- DeclRefType* type = nullptr;
-
- if (auto extensionDecl = as<ExtensionDecl>(decl))
- {
- // If this is an extension, use the provided target type.
- type = as<DeclRefType>(extensionDecl->targetType.type);
- }
- else
- {
- // If this is a type declaration, create a decl ref without
- // any substitutions.
- //
- auto declRef = makeDeclRef(decl);
-
- // TODO: Strip substitutions from the declreftype
- type = DeclRefType::create(m_astBuilder, declRef);
- }
-
- // Skip if the declaration is the interface itself.
- if (type->declRef == diffInterface)
- return;
-
- // If the DeclRefType conforms to IDifferentiable, register it with the top-level
- // context.
- //
- if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, diffInterface)))
- {
- // TODO: Temporarily disabled to move to new system. Fix later.
- // context->registerDifferentiableType(type, witness);
- }
-
- }
-
void SemanticsDeclHeaderVisitor::visitGenericTypeConstraintDecl(GenericTypeConstraintDecl* decl)
{
// TODO: are there any other validations we can do at this point?
@@ -1637,23 +1570,6 @@ namespace Slang
ensureDecl(constraint, DeclCheckState::ReadyForReference);
}
}
-
- // TODO(sai): Is this the right checking stage to be doing this?
- DifferentiableTypeSemanticContext diffTypeContext;
-
- for (Index i = 0; i < members.getCount(); ++i)
- {
- Decl* m = members[i];
-
- if (auto typeParam = as<GenericTypeParamDecl>(m))
- {
- tryAddDifferentiableConformanceToContext(typeParam, &diffTypeContext);
- }
- }
-
- auto diffTypeDictionaryNode = diffTypeContext.makeDifferentiableTypeDictionaryNode(m_astBuilder);
- diffTypeDictionaryNode->parentDecl = genericDecl;
- genericDecl->members.add(diffTypeDictionaryNode);
}
void SemanticsDeclBasesVisitor::visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
@@ -1689,7 +1605,6 @@ namespace Slang
void visitAggTypeDecl(AggTypeDecl* aggTypeDecl)
{
checkAggTypeConformance(aggTypeDecl);
- tryAddDifferentiableConformanceToContext(aggTypeDecl, getShared()->getDiffTypeContext());
}
// Conformances can also come via `extension` declarations, and
@@ -1698,7 +1613,6 @@ namespace Slang
void visitExtensionDecl(ExtensionDecl* extensionDecl)
{
checkExtensionConformance(extensionDecl);
- tryAddDifferentiableConformanceToContext(extensionDecl, getShared()->getDiffTypeContext());
}
};
@@ -1855,32 +1769,6 @@ namespace Slang
// Furthermore, because a fully checked function will have checked
// its body, this also means that all function bodies and the
// declarations they contain should be fully checked.
-
- // Generate a dictionary node to hold information about all
- // available differentiable types in scope (including imports and stdlib)
- //
- if (getShared()->getDiffTypeContext()->isDictionaryRequired())
- finishDifferentiableTypeDictionary(moduleDecl);
- }
-
- void SemanticsVisitor::finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl)
- {
- // Grab the differentiable type information from imported modules.
- for(auto importedModule : getShared()->importedModulesList)
- {
- this->getShared()->getDiffTypeContext()->addImportedModule(importedModule);
- }
-
- // Grad the differentiable type information from the standard library modules.
- for (auto stdLibModule : this->getSession()->stdlibModules)
- {
- this->getShared()->getDiffTypeContext()->addImportedModule(stdLibModule->getModuleDecl());
- }
-
- auto diffTypeDictNode = this->getShared()->getDiffTypeContext()->makeDifferentiableTypeDictionaryNode(m_astBuilder);
- diffTypeDictNode->parentDecl = moduleDecl;
-
- moduleDecl->members.add(diffTypeDictNode);
}
bool SemanticsVisitor::doesSignatureMatchRequirement(
@@ -5374,11 +5262,6 @@ namespace Slang
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
{
- if (decl->findModifier<DifferentiableAttribute>())
- {
- this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
- }
-
for(auto paramDecl : decl->getParameters())
{
ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
@@ -6249,75 +6132,6 @@ namespace Slang
m_mapTypeDeclToCandidateExtensions.Clear();
}
- void DifferentiableTypeSemanticContext::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness)
- {
- // Need to generate a type dictionary since we have a declaration that works with
- // a differentiable type.
- //
- this->requireDifferentiableTypeDictionary();
-
- m_mapTypeToIDifferentiableWitness.AddIfNotExists(DeclRefTypeKey(type), witness);
- }
-
- List<KeyValuePair<DeclRefType*, SubtypeWitness*>> DifferentiableTypeSemanticContext::getDifferentiableTypeConformanceList()
- {
- List<KeyValuePair<DeclRefType*, SubtypeWitness*>> diffConformances;
- for (auto entry : m_mapTypeToIDifferentiableWitness)
- {
- diffConformances.add(KeyValuePair<DeclRefType*, SubtypeWitness*>(entry.Key.type, entry.Value));
- }
-
- return diffConformances;
- }
-
- DifferentiableTypeDictionary* DifferentiableTypeSemanticContext::makeDifferentiableTypeDictionaryNode(
- ASTBuilder* builder)
- {
- auto dictionary = builder->create<DifferentiableTypeDictionary>();
-
- for (auto item : m_mapTypeToIDifferentiableWitness)
- {
- auto entry = builder->create<DifferentiableTypeDictionaryItem>();
- entry->baseType = item.Key.type;
- entry->confWitness = item.Value;
- entry->parentDecl = dictionary;
-
- dictionary->members.add(entry);
- }
-
- for (auto item : m_importedDictionaries)
- {
- auto entry = builder->create<DifferentiableTypeDictionaryImportItem>();
- entry->dictionaryRef = item;
- entry->parentDecl = dictionary;
-
- dictionary->members.add(entry);
- }
-
- return dictionary;
- }
-
- void DifferentiableTypeSemanticContext::addImportedModule(ModuleDecl* importedModuleDecl)
- {
- // TODO: This is a terribly slow way to find the diff type dictionary.
- // Switch to lookUp() when possible (this might involve naming the dictionary something)
- //
- for (auto diffTypeDict : importedModuleDecl->getMembersOfType<DifferentiableTypeDictionary>())
- {
- m_importedDictionaries.add(makeDeclRef(diffTypeDict));
- }
- }
-
- void DifferentiableTypeSemanticContext::requireDifferentiableTypeDictionary()
- {
- this->m_isTypeDictionaryRequired = true;
- }
-
- bool DifferentiableTypeSemanticContext::isDictionaryRequired()
- {
- return this->m_isTypeDictionaryRequired;
- }
-
void SharedSemanticsContext::_addCandidateExtensionsFromModule(ModuleDecl* moduleDecl)
{
for( auto& entry : moduleDecl->mapTypeToCandidateExtensions )
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 251849ede..ad199300a 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -899,6 +899,16 @@ namespace Slang
return result;
}
+ void SemanticsVisitor::registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness)
+ {
+ SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
+ if (witness)
+ {
+ m_parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.AddIfNotExists(type->declRef, witness);
+ }
+ }
+
+
void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
{
if (!builder->isDifferentiableInterfaceAvailable())
@@ -906,6 +916,11 @@ namespace Slang
return;
}
+ if (!m_parentDifferentiableAttr)
+ {
+ return;
+ }
+
// Check for special cases such as PtrTypeBase<T> or Array<T>
// This could potentially be handled later by simply defining extensions
// for Ptr<T:IDifferentiable> etc..
@@ -927,10 +942,8 @@ namespace Slang
if (auto subtypeWitness = as<SubtypeWitness>(
tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
{
- auto diffTypeContext = this->getShared()->innermostDiffTypeContext();
- diffTypeContext->registerDifferentiableType((DeclRefType*)type, subtypeWitness);
+ registerDifferentiableType((DeclRefType*)type, subtypeWitness);
}
-
return;
}
}
@@ -2007,20 +2020,9 @@ namespace Slang
Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
- this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
-
// Check/Resolve inner function declaration.
expr->baseFunction = CheckTerm(expr->baseFunction);
- // Register parameter types.
- if (auto funcType = as<FuncType>(expr->baseFunction->type.type))
- {
- for (UInt i = 0; i < funcType->getParamCount(); i++)
- {
- maybeRegisterDifferentiableType(m_astBuilder, funcType->getParamType(i));
- }
- }
-
// For now we only support using higher order expr as callee in an invoke expr.
// The actual type of the higher order function will be derived during resolve invoke.
expr->type = m_astBuilder->getBottomType();
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 33455e42d..a0141911a 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -214,75 +214,6 @@ namespace Slang
Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
};
-
- struct DifferentiableTypeSemanticContext
- {
-
- public:
- /// Registers a type as conforming to IDifferentiable, along with a witness
- /// describing the relationship.
- ///
- void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness);
-
- /// Returns the list of registered differentiable types.
- List<KeyValuePair<DeclRefType*, SubtypeWitness*>> getDifferentiableTypeConformanceList();
-
- /// Creates a DifferentiableTypeDictionary AST container node with an entry for
- /// every registered type. This can be inserted into the appropriate context for the
- /// auto-diff pass.
- ///
- DifferentiableTypeDictionary* makeDifferentiableTypeDictionaryNode(ASTBuilder* builder);
-
- /// Creates a DifferentiableTypeDictionary AST container node with an entry for
- /// every registered type. This can be inserted into the appropriate context for the
- /// auto-diff pass.
- ///
- void addImportedModule(ModuleDecl* importedModuleDecl);
-
- /// Set flag to indicate that the type dictionary is requried.
- void requireDifferentiableTypeDictionary();
-
- /// Returns flag indicating whether the type dictionary is requried.
- bool isDictionaryRequired();
-
- private:
- // Nested struct to override the '==' operator for DeclRefTypes
- struct DeclRefTypeKey
- {
- DeclRefType* type;
-
- DeclRefTypeKey(DeclRefType* type) : type(type)
- {};
-
- DeclRefTypeKey(DeclRefTypeKey& typeKey) : type(typeKey.type)
- {};
-
- DeclRefTypeKey() : type(nullptr)
- {};
-
- bool operator==(const DeclRefTypeKey& other) const
- {
- return (other.type->declRef == this->type->declRef);
- }
-
- HashCode getHashCode() const
- {
- Hasher hasher;
- hasher.hashObject(&type->declRef);
- return hasher.getResult();
- }
- };
-
- /// Mapping from types to subtype witnesses for conformance to IDifferentiable.
- OrderedDictionary<DeclRefTypeKey, SubtypeWitness*> m_mapTypeToIDifferentiableWitness;
-
- /// List of external dictionaries (from imported modules)
- List<DeclRef<DifferentiableTypeDictionary>> m_importedDictionaries;
-
- /// Flag to indicate if a differentiable type dictionary is required.
- bool m_isTypeDictionaryRequired = false;
- };
-
/// Shared state for a semantics-checking session.
struct SharedSemanticsContext
{
@@ -310,11 +241,6 @@ namespace Slang
//
List<ModuleDecl*> importedModulesList;
HashSet<ModuleDecl*> importedModulesSet;
-
- DifferentiableTypeSemanticContext diffTypeContext;
-
- List<DifferentiableTypeSemanticContext*> diffTypeContextStack;
-
public:
SharedSemanticsContext(
Linkage* linkage,
@@ -349,28 +275,6 @@ namespace Slang
return false;
}
- DifferentiableTypeSemanticContext* getDiffTypeContext()
- {
- return &diffTypeContext;
- }
-
- DifferentiableTypeSemanticContext* innermostDiffTypeContext()
- {
- return (diffTypeContextStack.getCount() > 0) ? diffTypeContextStack.getLast() : &diffTypeContext;
- }
-
- void pushDiffTypeContext(DifferentiableTypeSemanticContext* context)
- {
- diffTypeContextStack.add(context);
- }
-
- DifferentiableTypeSemanticContext* popDiffTypeContext()
- {
- auto context = diffTypeContextStack.getLast();
- diffTypeContextStack.removeLast();
- return context;
- }
-
/// Get the list of extension declarations that appear to apply to `decl` in this context
List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl);
@@ -431,6 +335,7 @@ namespace Slang
SemanticsContext result(*this);
result.m_parentFunc = parentFunc;
result.m_outerStmts = nullptr;
+ result.m_parentDifferentiableAttr = parentFunc->findModifier<DifferentiableAttribute>();
return result;
}
@@ -519,6 +424,8 @@ namespace Slang
/// The parent function (if any) that surrounds the statement being checked.
FunctionDeclBase* m_parentFunc = nullptr;
+ DifferentiableAttribute* m_parentDifferentiableAttr = nullptr;
+
/// The linked list of lexically surrounding statements.
OuterStmtInfo* m_outerStmts = nullptr;
@@ -801,6 +708,11 @@ namespace Slang
// Convert a function's original type to it's JVP type.
Type* processJVPFuncType(FuncType* originalType);
+ /// Registers a type as conforming to IDifferentiable, along with a witness
+ /// describing the relationship.
+ ///
+ void registerDifferentiableType(DeclRefType* type, SubtypeWitness* witness);
+
// Check and register a type if it is differentiable.
void maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type);
@@ -1129,16 +1041,6 @@ namespace Slang
DeclRef<AssocTypeDecl> requirementDeclRef,
RefPtr<WitnessTable> witnessTable);
- /// Registers a type as differentiable in the currrent semantic context, if the declaration represents
- /// a subtype of IDifferentable. Does nothing otherwise.
- void tryAddDifferentiableConformanceToContext(
- Decl* decl,
- DifferentiableTypeSemanticContext* context);
-
- /// Generates a dictionary node for the module with all registered differentiable types,
- /// as well as information about differentiable types in imported modules.
- void finishDifferentiableTypeDictionary(ModuleDecl* moduleDecl);
-
// Find the appropriate member of a declared type to
// satisfy a requirement of an interface the type
// claims to conform to.
diff --git a/source/slang/slang-check-type.cpp b/source/slang/slang-check-type.cpp
index 6bc4b9d36..d402dde03 100644
--- a/source/slang/slang-check-type.cpp
+++ b/source/slang/slang-check-type.cpp
@@ -320,19 +320,6 @@ namespace Slang
getSink()->diagnose(typeExp.exp, Diagnostics::cannotDefinePtrTypeToManagedResource);
}
}
-
- // Differentiable type checking.
- // TODO: This can be super slow. Switch to caching the result asap.
- if (this->m_parentFunc &&
- this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>())
- {
- auto diffTypeContext = this->getShared()->innermostDiffTypeContext();
- if (auto subtypeWitness = as<SubtypeWitness>(
- tryGetInterfaceConformanceWitness(result, getASTBuilder()->getDifferentiableInterface())))
- {
- diffTypeContext->registerDifferentiableType((DeclRefType*)result, subtypeWitness);
- }
- }
*outProperType = result;
return true;
diff --git a/source/slang/slang-emit-cpp.cpp b/source/slang/slang-emit-cpp.cpp
index dd3acff78..f62007bb0 100644
--- a/source/slang/slang-emit-cpp.cpp
+++ b/source/slang/slang-emit-cpp.cpp
@@ -7,6 +7,7 @@
#include "slang-mangled-lexer.h"
#include "slang-ir-clone.h"
+#include "slang-ir-util.h"
#include "../compiler-core/slang-artifact-desc-util.h"
@@ -80,39 +81,6 @@ static UnownedStringSlice _getTypePrefix(IROp op)
}
}
-static IROp _getTypeStyle(IROp op)
-{
- switch (op)
- {
- case kIROp_VoidType:
- case kIROp_BoolType:
- {
- return op;
- }
- case kIROp_Int8Type:
- case kIROp_Int16Type:
- case kIROp_IntType:
- case kIROp_UInt8Type:
- case kIROp_UInt16Type:
- case kIROp_UIntType:
- case kIROp_Int64Type:
- case kIROp_UInt64Type:
- case kIROp_IntPtrType:
- case kIROp_UIntPtrType:
- {
- // All int like
- return kIROp_IntType;
- }
- case kIROp_HalfType:
- case kIROp_FloatType:
- case kIROp_DoubleType:
- {
- // All float like
- return kIROp_FloatType;
- }
- default: return kIROp_Invalid;
- }
-}
static IROp _getCType(IROp op)
{
@@ -912,7 +880,7 @@ void CPPSourceEmitter::_emitAnyAllDefinition(const UnownedStringSlice& funcName,
IRType* retType = specOp->returnType;
auto retTypeName = _getTypeName(retType);
- IROp style = _getTypeStyle(elementType->getOp());
+ IROp style = getTypeStyle(elementType->getOp());
const TypeDimension dim = _getTypeDimension(paramType0, false);
diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp
index 7d677b488..d58e307da 100644
--- a/source/slang/slang-ir-dce.cpp
+++ b/source/slang/slang-ir-dce.cpp
@@ -361,13 +361,6 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o
case kIROp_WitnessTableEntry:
return true;
- // Special dictionaries used for differentiable type tracking
- // should be kept alive. These are removed by the auto-diff pass,
- // once they are used.
- case kIROp_DifferentiableTypeDictionaryItem:
- case kIROp_DifferentiableTypeDictionary:
- return true;
-
default:
break;
}
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 3d02d4fc0..d0bf8f347 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -6,6 +6,7 @@
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
+#include "slang-ir-util.h"
// origX, primalX, diffX
// origX -> primalX (cloneEnv)
@@ -26,11 +27,9 @@ struct Pair
typedef Pair<IRInst*, IRInst*> InstPair;
-struct DifferentiableTypeConformanceContext
+struct AutoDiffSharedContext
{
- Dictionary<IRInst*, IRInst*> witnessTableMap;
-
- IRInst* inst = nullptr;
+ IRModuleInst* moduleInst = nullptr;
// A reference to the builtin IDifferentiable interface type.
// We use this to look up all the other types (and type exprs)
@@ -62,114 +61,27 @@ struct DifferentiableTypeConformanceContext
//
bool isInterfaceAvailable = false;
- // For handling generic blocks, we use a parent pointer to allow
- // looking up types in all relevant scopes.
- DifferentiableTypeConformanceContext* parent = nullptr;
- DifferentiableTypeConformanceContext(DifferentiableTypeConformanceContext* parent, IRInst* inst) : parent(parent), inst(inst)
+ AutoDiffSharedContext(IRModuleInst* inModuleInst)
+ : moduleInst(inModuleInst)
{
- if (parent)
+ differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
+ if (differentiableInterfaceType)
{
- differentiableInterfaceType = parent->differentiableInterfaceType;
- differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey;
- zeroMethodStructKey = parent->zeroMethodStructKey;
- addMethodStructKey = parent->addMethodStructKey;
-
- isInterfaceAvailable = parent->isInterfaceAvailable;
- }
- else
- {
- differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
- if (differentiableInterfaceType)
- {
- differentialAssocTypeStructKey = findDifferentialTypeStructKey();
- zeroMethodStructKey = findZeroMethodStructKey();
- addMethodStructKey = findAddMethodStructKey();
-
- if (differentialAssocTypeStructKey)
- isInterfaceAvailable = true;
- }
- }
- }
-
- DifferentiableTypeConformanceContext(IRInst* inst) :
- DifferentiableTypeConformanceContext(nullptr, inst)
- {}
+ differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ zeroMethodStructKey = findZeroMethodStructKey();
+ addMethodStructKey = findAddMethodStructKey();
- // Lookup a witness table for the concreteType. One should exist if concreteType
- // inherits (successfully) from IDifferentiable.
- //
- IRInst* lookUpConformanceForType(IRBuilder* builder, IRInst* type)
- {
- SLANG_ASSERT(isInterfaceAvailable);
- // TODO: Cache the returned value to avoid repeatedly scanning through
- // blocks looking for the type entries.
- //
- if (auto irWitness = builder->findDifferentiableTypeEntry(type, type->getParent()))
- {
- return irWitness;
- }
-
- return nullptr;
- }
-
- IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
- {
- if (auto conformance = lookUpConformanceForType(builder, origType))
- {
- if (auto witnessTable = as<IRWitnessTable>(conformance))
- {
- for (auto entry : witnessTable->getEntries())
- {
- if (entry->getRequirementKey() == key)
- return entry->getSatisfyingVal();
- }
- }
- else if (auto witnessTableParam = as<IRParam>(conformance))
- {
- return builder->emitLookupInterfaceMethodInst(
- builder->getTypeKind(),
- witnessTableParam,
- key);
- }
- }
-
- return nullptr;
- }
-
- // Lookup and return the 'Differential' type declared in the concrete type
- // in order to conform to the IDifferentiable interface.
- // Note that inside a generic block, this will be a witness table lookup instruction
- // that gets resolved during the specialization pass.
- //
- IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
- {
- switch (origType->getOp())
- {
- case kIROp_FloatType:
- case kIROp_HalfType:
- case kIROp_DoubleType:
- case kIROp_VectorType:
- return origType;
+ if (differentialAssocTypeStructKey)
+ isInterfaceAvailable = true;
}
- return lookUpInterfaceMethod(builder, origType, differentialAssocTypeStructKey);
- }
-
- IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, zeroMethodStructKey);
- }
-
- IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
- {
- return lookUpInterfaceMethod(builder, origType, addMethodStructKey);
}
private:
IRInst* findDifferentiableInterface()
{
- if (auto module = as<IRModuleInst>(inst))
+ if (auto module = as<IRModuleInst>(moduleInst))
{
for (auto globalInst : module->getGlobalInsts())
{
@@ -203,7 +115,7 @@ struct DifferentiableTypeConformanceContext
IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
{
- if (as<IRModuleInst>(inst) && differentiableInterfaceType)
+ if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
{
// Assume for now that IDifferentiable has exactly four fields.
SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
@@ -217,110 +129,126 @@ struct DifferentiableTypeConformanceContext
return nullptr;
}
+};
- void loadWitnessTablesForInterface(IRInst* interfaceType)
+namespace
+{
+
+IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey)
+{
+ if (auto witnessTable = as<IRWitnessTable>(witness))
{
-
- if (auto module = as<IRModuleInst>(inst))
+ for (auto entry : witnessTable->getEntries())
{
- for (auto globalInst : module->getGlobalInsts())
- {
- if (globalInst->getOp() == kIROp_WitnessTable &&
- cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() ==
- interfaceType)
- {
- // TODO: Can we have multiple conformances for the same pair of types?
- // TODO: Can type instrs be duplicated (i.e. two different float types)? And if they are duplicated, can
- // we supply the dictionary with a custom equality rule that uses 'type1->equals(type2)'
- witnessTableMap.Add(as<IRWitnessTable>(globalInst)->getConcreteType(), globalInst);
- }
- }
+ if (entry->getRequirementKey() == requirementKey)
+ return entry->getSatisfyingVal();
}
- else if (auto generic = as<IRGeneric>(inst))
- {
- List<IRParam*> typeParams;
+ }
+ else if (auto witnessTableParam = as<IRParam>(witness))
+ {
+ return builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ witnessTableParam,
+ requirementKey);
+ }
+ return nullptr;
+}
+
+}
+
+struct DifferentiableTypeConformanceContext
+{
+ AutoDiffSharedContext* sharedContext;
+
+ IRGlobalValueWithCode* parentFunc = nullptr;
+ Dictionary<IRType*, IRInst*> differentiableWitnessDictionary;
+
+ DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared)
+ : sharedContext(shared)
+ {}
+
+ void setFunc(IRGlobalValueWithCode* func)
+ {
+ parentFunc = func;
+
+ auto decor = func->findDecoration<IRDifferentiableTypeDictionaryDecoration>();
+ SLANG_RELEASE_ASSERT(decor);
- auto genericParam = generic->getFirstParam();
- while (genericParam)
+ // Build lookup dictionary for type witnesses.
+ for (auto child = decor->getFirstChild(); child; child = child->next)
+ {
+ if (auto item = as<IRDifferentiableTypeDictionaryItem>(child))
{
- if (as<IRTypeType>(genericParam->getDataType()))
+ auto existingItem = differentiableWitnessDictionary.TryGetValue(item->getConcreteType());
+ if (existingItem)
{
- typeParams.add(genericParam);
+ if (auto witness = as<IRWitnessTable>(item->getWitness()))
+ {
+ if (witness->getConcreteType()->getOp() == kIROp_DifferentialBottomType)
+ continue;
+ }
+ *existingItem = item->getWitness();
}
else
- break;
-
- genericParam = genericParam->getNextParam();
- }
-
- Count tableIndex = 0;
- while (genericParam)
- {
- SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType()));
-
- if (tableIndex >= typeParams.getCount())
- break;
-
- if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType()))
{
- // TODO(sai): Heavily flawed way to find the right witness table.
- // Rewrite this part
- if (witnessTableType->getConformanceType() == differentiableInterfaceType)
- witnessTableMap.Add(typeParams[tableIndex], genericParam);
+ differentiableWitnessDictionary.Add((IRType*)item->getConcreteType(), item->getWitness());
}
- else
- break;
-
- tableIndex += 1;
- genericParam = genericParam->getNextParam();
}
-
}
-
}
-};
-
-IRInst* findGlobal(IRInst* inst)
-{
- if (inst->getParent() != inst->getModule()->getModuleInst())
+ // Lookup a witness table for the concreteType. One should exist if concreteType
+ // inherits (successfully) from IDifferentiable.
+ //
+ IRInst* lookUpConformanceForType(IRInst* type)
{
- return findGlobal(inst->getParent());
+ IRInst* foundResult = nullptr;
+ differentiableWitnessDictionary.TryGetValue(type, foundResult);
+ return foundResult;
}
- return inst;
-}
-
-void moveGlobalToBeforeUses(IRBuilder*, IRInst* globalInst)
-{
- HashSet<IRInst*> globalsOfUses;
- for (auto use = globalInst->firstUse; use; use = use->nextUse)
+ IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key)
{
- globalsOfUses.Add(findGlobal(use->getUser()));
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ return _lookupWitness(builder, conformance, key);
+ }
+ return nullptr;
}
- IRInst* earliestUse = nullptr;
- for (auto cursor = globalInst; cursor; cursor = cursor->getPrevInst())
- {
- if (globalsOfUses.Contains(cursor))
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ switch (origType->getOp())
{
- earliestUse = cursor;
+ case kIROp_FloatType:
+ case kIROp_HalfType:
+ case kIROp_DoubleType:
+ case kIROp_VectorType:
+ return origType;
}
+ return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey);
}
- if (earliestUse)
+ IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType)
{
- globalInst->insertBefore(earliestUse);
+ return lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey);
+ }
+
+ IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType)
+ {
+ return lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey);
}
-}
+
+};
struct DifferentialPairTypeBuilder
{
-
- DifferentialPairTypeBuilder(DifferentiableTypeConformanceContext* diffConformanceContext) :
- diffConformanceContext(diffConformanceContext)
- {}
IRStructField* findField(IRInst* type, IRStructKey* key)
{
@@ -454,14 +382,6 @@ struct DifferentialPairTypeBuilder
return emitFieldAccessor(builder, baseInst, this->globalDiffKey);
}
- void relocateNewTypes(IRBuilder* builder)
- {
- for (auto typeInst : generatedTypeList)
- {
- moveGlobalToBeforeUses(builder, typeInst);
- }
- }
-
IRStructKey* _getOrCreateDiffStructKey(IRBuilder* builder)
{
if (!this->globalDiffKey)
@@ -496,27 +416,23 @@ struct DifferentialPairTypeBuilder
return this->globalPrimalKey;
}
- IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* _createDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType)
{
- if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
- {
- SLANG_ASSERT(!as<IRParam>(origBaseType));
-
- auto pairStructType = builder->createStructType();
- builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
- builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*) diffBaseType);
+ SLANG_ASSERT(!as<IRParam>(origBaseType));
+ SLANG_ASSERT(diffType);
+ auto pairStructType = builder->createStructType();
+ builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
+ builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType);
- return pairStructType;
- }
- return nullptr;
+ return pairStructType;
}
- IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType)
{
if (pairTypeCache.ContainsKey(origBaseType))
return pairTypeCache[origBaseType];
- auto pairType = _createDiffPairType(builder, origBaseType);
+ auto pairType = _createDiffPairType(builder, origBaseType, diffType);
pairTypeCache.Add(origBaseType, pairType);
return pairType;
@@ -524,8 +440,6 @@ struct DifferentialPairTypeBuilder
Dictionary<IRInst*, IRInst*> pairTypeCache;
- DifferentiableTypeConformanceContext* diffConformanceContext;
-
IRStructKey* globalPrimalKey = nullptr;
IRStructKey* globalDiffKey = nullptr;
@@ -553,11 +467,17 @@ struct JVPTranscriber
DiagnosticSink* sink;
// Type conformance information.
- DifferentiableTypeConformanceContext* diffConformanceContext;
+ AutoDiffSharedContext* autoDiffSharedContext;
// Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
DifferentialPairTypeBuilder* pairBuilder;
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ JVPTranscriber(AutoDiffSharedContext* shared)
+ : differentiableTypeConformanceContext(shared)
+ {}
+
DiagnosticSink* getSink()
{
SLANG_ASSERT(sink);
@@ -692,7 +612,7 @@ struct JVPTranscriber
{
case kIROp_Param:
if (as<IRTypeType>(primalType->getDataType()))
- return (IRType*)(diffConformanceContext->getDifferentialForType(
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
builder,
(IRType*)primalType));
else if (as<IRWitnessTableType>(primalType->getDataType()))
@@ -737,7 +657,7 @@ struct JVPTranscriber
}
default:
- return (IRType*)(diffConformanceContext->getDifferentialForType(builder, (IRType*)primalType));
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
}
}
@@ -753,8 +673,10 @@ struct JVPTranscriber
else
return nullptr;
}
-
- return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType);
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType);
+ return nullptr;
}
InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
@@ -1325,7 +1247,7 @@ struct JVPTranscriber
{
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
- auto zeroMethod = this->diffConformanceContext->getZeroMethodForType(builder, primalType);
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
SLANG_ASSERT(zeroMethod);
auto emptyArgList = List<IRInst*>();
@@ -1333,6 +1255,11 @@ struct JVPTranscriber
}
else
{
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
getSink()->diagnose(primalType->sourceLoc,
Diagnostics::internalCompilerError,
"could not generate zero value for given type");
@@ -1359,17 +1286,6 @@ struct JVPTranscriber
for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
this->transcribe(builder, param);
- // Look for the differentiable type dictionary and clone it (and anything else we might need).
- // TODO: This logic might have issues if there are additional instructions (say lookup_interface_requirement)
- // that are operands.
- // TODO: This is currently cloning the global dictionary. Should only clone dictionaries in generic blocks.
- if (auto origDict = builder->findDifferentiableTypeDictionary(origBlock))
- {
- auto clonedDict = cloneInst(&cloneEnv, builder, origDict);
- mapPrimalInst(origDict, clonedDict);
- mapDifferentialInst(origDict, clonedDict);
- }
-
// Then, run through every instruction and use the transcriber to generate the appropriate
// derivative code.
//
@@ -1547,6 +1463,8 @@ struct JVPTranscriber
{
IRFunc* primalFunc = nullptr;
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
auto oldLoc = builder->getInsertLoc();
// If this is a top-level function, there is no need to clone it
@@ -1602,6 +1520,16 @@ struct JVPTranscriber
// Transcribe a generic definition
InstPair transcribeGeneric(IRBuilder* builder, IRGeneric* origGeneric)
{
+ auto innerVal = findInnerMostGenericReturnVal(origGeneric);
+ if (auto innerFunc = as<IRFunc>(innerVal))
+ {
+ differentiableTypeConformanceContext.setFunc(innerFunc);
+ }
+ else
+ {
+ return InstPair(origGeneric, nullptr);
+ }
+
// For now, we assume there's only one generic layer. So this inst must be top level
bool isTopLevel = (as<IRModuleInst>(origGeneric->getParent()) != nullptr);
SLANG_RELEASE_ASSERT(isTopLevel);
@@ -1757,10 +1685,6 @@ struct JVPTranscriber
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
- case kIROp_DifferentiableTypeDictionary:
- // Ignore dictionary insts.
- return InstPair(nullptr, nullptr);
-
}
// If none of the cases have been hit, check if the instruction is a
@@ -1885,11 +1809,8 @@ struct JVPDerivativeContext
// IRDifferentialPairGetPrimal with 'primal' field access, and
// IRMakeDifferentialPair with an IRMakeStruct.
//
- modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage));
+ modified |= processPairTypes(builder, module->getModuleInst());
- // Temporary fix: Move generated types, if any, to before their use locations.
- (&pairBuilderStorage)->relocateNewTypes(builder);
-
return modified;
}
@@ -1981,7 +1902,7 @@ struct JVPDerivativeContext
return true;
}
- IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext*)
+ IRInst* lowerPairType(IRBuilder* builder, IRType* type)
{
if (auto pairType = as<IRDifferentialPairType>(type))
@@ -1990,13 +1911,18 @@ struct JVPDerivativeContext
if (!as<IRType>(pairType->getValueType()))
{
- // Do not handle non-concrete types.
return nullptr;
}
-
+ auto witness = pairType->getWitness();
+ auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey);
+ if (!diffType)
+ {
+ return nullptr;
+ }
auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
builder,
- pairType->getValueType());
+ pairType->getValueType(),
+ (IRType*)(diffType));
pairType->replaceUsesWith(diffPairStructType);
pairType->removeAndDeallocate();
@@ -2017,12 +1943,12 @@ struct JVPDerivativeContext
return nullptr;
}
- IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
+ IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
{
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
- if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext))
+ if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType()))
{
builder->setInsertBefore(makePairInst);
@@ -2041,11 +1967,11 @@ struct JVPDerivativeContext
return nullptr;
}
- IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
+ IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst)
{
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
{
- if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext))
+ if (lowerPairType(builder, getDiffInst->getBase()->getDataType()))
{
builder->setInsertBefore(getDiffInst);
@@ -2057,7 +1983,7 @@ struct JVPDerivativeContext
}
else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
{
- if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext))
+ if (lowerPairType(builder, getPrimalInst->getBase()->getDataType()))
{
builder->setInsertBefore(getPrimalInst);
@@ -2072,16 +1998,10 @@ struct JVPDerivativeContext
return nullptr;
}
- bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren, DifferentiableTypeConformanceContext* diffContext)
+ bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
- // Create a new sub-context to scan witness tables inside workItem
- // (mainly relevant if instWithChildren is a generic scope)
- //
- auto subContext = DifferentiableTypeConformanceContext(diffContext, instWithChildren);
- (&pairBuilderStorage)->diffConformanceContext = (&subContext);
-
for (auto child = instWithChildren->getFirstChild(); child; )
{
// Make sure the builder is at the right level.
@@ -2092,53 +2012,21 @@ struct JVPDerivativeContext
switch (child->getOp())
{
case kIROp_DifferentialPairType:
- lowerPairType(builder, as<IRType>(child), &subContext);
+ lowerPairType(builder, as<IRType>(child));
break;
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
- lowerPairAccess(builder, child, &subContext);
+ lowerPairAccess(builder, child);
break;
case kIROp_MakeDifferentialPair:
- lowerMakePair(builder, child, &subContext);
+ lowerMakePair(builder, child);
break;
default:
if (child->getFirstChild())
- modified = processPairTypes(builder, child, (&subContext)) | modified;
- }
-
- child = nextChild;
- }
-
- // Reset the context back to the parent.
- (&pairBuilderStorage)->diffConformanceContext = diffContext;
-
- return modified;
- }
-
- bool stripDiffTypeInformation(IRInst* parent)
- {
- bool modified = false;
-
- auto child = parent->getFirstChild();
- while (child)
- {
- auto nextChild = child->getNextInst();
-
- switch (child->getOp())
- {
- case kIROp_DifferentiableTypeDictionary:
- child->removeAndDeallocate();
- child = nextChild;
- modified = true;
- continue;
- }
-
- if (child->getFirstChild() != nullptr)
- {
- modified |= stripDiffTypeInformation(child);
+ modified = processPairTypes(builder, child) | modified;
}
child = nextChild;
@@ -2186,12 +2074,13 @@ struct JVPDerivativeContext
}
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
- module(module), sink(sink),
- diffConformanceContextStorage(module->getModuleInst()),
- pairBuilderStorage(&diffConformanceContextStorage)
+ module(module),
+ sink(sink),
+ autoDiffSharedContextStorage(module->getModuleInst()),
+ transcriberStorage(&autoDiffSharedContextStorage)
{
transcriberStorage.sink = sink;
- transcriberStorage.diffConformanceContext = &(diffConformanceContextStorage);
+ transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);
transcriberStorage.pairBuilder = &(pairBuilderStorage);
}
@@ -2221,7 +2110,7 @@ struct JVPDerivativeContext
// Context to find and manage the witness tables for types
// implementing `IDifferentiable`
- DifferentiableTypeConformanceContext diffConformanceContextStorage;
+ AutoDiffSharedContext autoDiffSharedContextStorage;
// Builder for dealing with differential pair types.
DifferentialPairTypeBuilder pairBuilderStorage;
@@ -2243,7 +2132,6 @@ bool processForwardDifferentiableFuncs(
JVPDerivativeContext context(module, sink);
bool changed = context.processModule();
- changed |= context.stripDiffTypeInformation(module->getModuleInst());
return changed;
}
@@ -2258,6 +2146,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
case kIROp_ForwardDerivativeDecoration:
case kIROp_DerivativeMemberDecoration:
+ case kIROp_DifferentiableTypeDictionaryDecoration:
decor->removeAndDeallocate();
break;
default:
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 61aa28bbe..431446f01 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -715,6 +715,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// the witness table to be easily picked up by emit.
INST(COMWitnessDecoration, COMWitnessDecoration, 1, 0)
+ /* Differentiable Type Dictionary */
+ INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT)
+
/// Marks a struct type as being used as a structured buffer block.
/// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration.
INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0)
@@ -812,7 +815,6 @@ INST(ExistentialFuncSpecializationDictionary, ExistentialFuncSpecializationDicti
INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDictionary, 0, PARENT)
/* Differentiable Type Dictionary */
-INST(DifferentiableTypeDictionary, DifferentiableTypeDictionary, 0, PARENT)
INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0)
#undef PARENT
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index deb81134b..989777944 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -598,6 +598,14 @@ struct IRForwardDifferentiate : IRInst
struct IRDifferentiableTypeDictionaryItem : IRInst
{
IR_LEAF_ISA(DifferentiableTypeDictionaryItem)
+
+ IRInst* getConcreteType() { return getOperand(0); }
+ IRInst* getWitness() { return getOperand(1); }
+};
+
+struct IRDifferentiableTypeDictionaryDecoration : IRInst
+{
+ IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration)
};
@@ -2490,26 +2498,10 @@ public:
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
- // Emit and return a dictionary instruction to the global or generic scope.
- IRInst* emitDifferentiableTypeDictionary();
-
- // Emit and return a dictionary instruction to the global or generic scope,
- // if one is not already present.
- //
- IRInst* findOrEmitDifferentiableTypeDictionary();
-
- // Returns the IRDifferentiableTypeDictionary in the scope of inst.
- IRInst* findDifferentiableTypeDictionary(IRInst* inst);
+ IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target);
// Add a differentiable type entry to the appropriate dictionary.
- IRInst* addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness);
-
- // Lookup a differentiable type entry in the appropriate dictionary.
- // This recursively looks up in upper contexts.
- //
- IRInst* findDifferentiableTypeEntry(IRInst* irType);
-
- IRInst* findDifferentiableTypeEntry(IRInst* irType, IRInst* scope);
+ IRInst* addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness);
IRInst* emitSpecializeInst(
IRType* type,
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index eb899b69c..ad4f691f1 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -238,7 +238,6 @@ IRInst* IRSpecContext::maybeCloneValue(IRInst* originalValue)
case kIROp_WitnessTable:
case kIROp_InterfaceType:
case kIROp_TaggedUnionType:
- case kIROp_DifferentiableTypeDictionary:
return cloneGlobalValue(this, originalValue);
case kIROp_BoolLit:
@@ -593,24 +592,6 @@ IRWitnessTable* cloneWitnessTableImpl(
return clonedTable;
}
-IRInst* cloneDifferentiableTypeDictionary(
- IRSpecContextBase* context,
- IRBuilder* builder,
- IRInst* originalDict,
- IROriginalValuesForClone const& originalValues,
- IRInst* dstDict = nullptr,
- bool registerValue = true)
-{
- IRInst* clonedDict = dstDict;
- if (!clonedDict)
- {
- clonedDict = builder->emitDifferentiableTypeDictionary();
- }
- cloneSimpleGlobalValueImpl(context, originalDict, originalValues, clonedDict, registerValue);
- return clonedDict;
-}
-
-
IRWitnessTable* cloneWitnessTableWithoutRegistering(
IRSpecContextBase* context,
IRBuilder* builder,
@@ -1138,9 +1119,6 @@ IRInst* cloneInst(
case kIROp_GlobalGenericParam:
return cloneGlobalGenericParamImpl(context, builder, cast<IRGlobalGenericParam>(originalInst), originalValues);
- case kIROp_DifferentiableTypeDictionary:
- return cloneDifferentiableTypeDictionary(context, builder, originalInst, originalValues);
-
default:
break;
}
@@ -1164,9 +1142,8 @@ IRInst* cloneInst(
}
builder->addInst(clonedInst);
context->builder = oldBuilder;
- cloneDecorations(context, clonedInst, originalInst);
+ cloneDecorationsAndChildren(context, clonedInst, originalInst);
cloneExtraDecorations(context, clonedInst, originalValues);
-
return clonedInst;
}
@@ -1530,10 +1507,6 @@ LinkedIR linkIR(
{
cloneValue(context, bindInst);
}
- else if (inst->getOp() == kIROp_DifferentiableTypeDictionary)
- {
- cloneValue(context, inst);
- }
}
}
diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp
index 1f13eb754..214f10ef9 100644
--- a/source/slang/slang-ir-util.cpp
+++ b/source/slang/slang-ir-util.cpp
@@ -66,5 +66,38 @@ bool isComInterfaceType(IRType* type)
return false;
}
+IROp getTypeStyle(IROp op)
+{
+ switch (op)
+ {
+ case kIROp_VoidType:
+ case kIROp_BoolType:
+ {
+ return op;
+ }
+ case kIROp_Int8Type:
+ case kIROp_Int16Type:
+ case kIROp_IntType:
+ case kIROp_UInt8Type:
+ case kIROp_UInt16Type:
+ case kIROp_UIntType:
+ case kIROp_Int64Type:
+ case kIROp_UInt64Type:
+ case kIROp_IntPtrType:
+ case kIROp_UIntPtrType:
+ {
+ // All int like
+ return kIROp_IntType;
+ }
+ case kIROp_HalfType:
+ case kIROp_FloatType:
+ case kIROp_DoubleType:
+ {
+ // All float like
+ return kIROp_FloatType;
+ }
+ default: return kIROp_Invalid;
+ }
+}
}
diff --git a/source/slang/slang-ir-util.h b/source/slang/slang-ir-util.h
index 2300c929d..b6690a28c 100644
--- a/source/slang/slang-ir-util.h
+++ b/source/slang/slang-ir-util.h
@@ -24,6 +24,14 @@ Dictionary<IRInst*, IRInst*> buildInterfaceRequirementDict(IRInterfaceType* inte
bool isComInterfaceType(IRType* type);
+
+IROp getTypeStyle(IROp op);
+
+inline bool isScalarIntegerType(IRType* type)
+{
+ return getTypeStyle(type->getOp()) == kIROp_IntType;
+}
+
}
#endif
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 382f7be5e..083ef98c5 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3546,133 +3546,35 @@ namespace Slang
value->insertAtEnd(parent);
}
}
-
- IRInst* IRBuilder::emitDifferentiableTypeDictionary()
- {
- auto inst = createInst<IRInst>(
- this,
- kIROp_DifferentiableTypeDictionary,
- nullptr);
-
- addGlobalValue(this, inst);
- return inst;
- }
-
- IRInst* IRBuilder::findOrEmitDifferentiableTypeDictionary()
+ IRInst* IRBuilder::addDifferentiableTypeDictionaryDecoration(IRInst* target)
{
- auto currentLoc = this->getInsertLoc();
- auto currentInst = currentLoc.getInst();
-
- if (auto diffTypeDictionary = findDifferentiableTypeDictionary(currentInst))
- return diffTypeDictionary;
-
- return emitDifferentiableTypeDictionary();
+ return addDecoration(target, kIROp_DifferentiableTypeDictionaryDecoration);
}
- IRInst* IRBuilder::findDifferentiableTypeDictionary(IRInst* parent)
- {
- //auto parent = inst->getParent();
- while (parent)
- {
- // Inserting into the top level of a module?
- // That is fine, and we can stop searching.
- if (as<IRModuleInst>(parent))
- break;
-
- // Inserting into a basic block inside of
- // a generic? That is okay too.
- if (auto block = as<IRBlock>(parent))
- {
- if (as<IRGeneric>(block->parent))
- break;
- }
-
- // Otherwise, move up the chain.
- parent = parent->parent;
- }
-
- for (auto child = parent->getFirstChild(); child; child = child->getNextInst())
- {
- if (child->getOp() == kIROp_DifferentiableTypeDictionary)
- return child;
- }
-
- return nullptr;
- }
-
- IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* irType, IRInst* conformanceWitness)
+ IRInst* IRBuilder::addDifferentiableTypeEntry(IRInst* dictDecoration, IRInst* irType, IRInst* conformanceWitness)
{
auto oldLoc = this->getInsertLoc();
IRDifferentiableTypeDictionaryItem* item = nullptr;
- if (auto diffTypeDictionary = findOrEmitDifferentiableTypeDictionary())
- {
- this->setInsertInto(diffTypeDictionary);
+ this->setInsertInto(dictDecoration);
- IRInst* args[2] = {irType, conformanceWitness};
- item = createInstWithTrailingArgs<IRDifferentiableTypeDictionaryItem>(
- this,
- kIROp_DifferentiableTypeDictionaryItem,
- nullptr,
- 2,
- args);
+ IRInst* args[2] = {irType, conformanceWitness};
+ item = createInstWithTrailingArgs<IRDifferentiableTypeDictionaryItem>(
+ this,
+ kIROp_DifferentiableTypeDictionaryItem,
+ nullptr,
+ 2,
+ args);
- addInst(item);
- }
+ addInst(item);
this->setInsertLoc(oldLoc);
return item;
}
- IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType, IRInst* scope)
- {
- IRInst* foundResult = nullptr;
- for (auto child = scope->getFirstChild(); child; child = child->getNextInst())
- {
- if (child->getOp() == kIROp_DifferentiableTypeDictionary)
- {
- for (auto entry = child->getFirstChild(); entry; entry = entry->getNextInst())
- {
- IRInst* entryType = entry->getOperand(0);
- IRInst* entryConformanceWitness = entry->getOperand(1);
-
- if (irType == entryType)
- {
- foundResult = entryConformanceWitness;
- // If the found witness table is not a trivial one (i.e. DifferentialBottom:IDifferential),
- // return immediately. Otherwise, continue the search to see if we can find a better one.
- if (auto witness = as<IRWitnessTable>(foundResult))
- {
- if (witness->getConcreteType()->getOp() != kIROp_DifferentialBottomType)
- return foundResult;
- }
- }
- }
- }
- }
-
- return foundResult;
- }
-
- IRInst* IRBuilder::findDifferentiableTypeEntry(IRInst* irType)
- {
- auto instScope = this->getInsertLoc().getInst();
-
- while (instScope)
- {
- if (auto witness = findDifferentiableTypeEntry(irType, instScope))
- {
- return witness;
- }
- instScope = instScope->getParent();
- }
-
- return nullptr;
- }
-
IRFunc* IRBuilder::createFunc()
{
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index a2fb1be98..9295ca2f5 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1323,6 +1323,7 @@ SIMPLE_IR_TYPE(GenericKind, Kind)
struct IRDifferentialPairType : IRType
{
IRType* getValueType() { return (IRType*)getOperand(0); }
+ IRInst* getWitness() { return (IRInst*)getOperand(1); }
IR_LEAF_ISA(DifferentialPairType)
};
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index e2b14f1e3..f8d8282d8 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -5866,47 +5866,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo();
}
- LoweredValInfo visitDifferentiableTypeDictionary(DifferentiableTypeDictionary* decl)
- {
- for (auto & member : decl->members)
- {
- if (auto entry = as<DifferentiableTypeDictionaryItem>(member))
- {
-
- // Lower type and witness.
- IRType* irType = lowerType(context, entry->baseType);
- IRInst* irWitness = lowerVal(context, entry->confWitness).val;
-
- SLANG_ASSERT(irType);
-
- // If the witness can be lowered, and the differentiable type entry exists,
- // add an entry to the context.
- //
- if (irWitness && !getBuilder()->findDifferentiableTypeEntry(irType))
- {
- getBuilder()->addDifferentiableTypeEntry(irType, irWitness);
- }
- }
- else if (auto importEntry = as<DifferentiableTypeDictionaryImportItem>(member))
- {
- ensureDecl(context, importEntry->dictionaryRef.getDecl());
- }
- else
- {
- SLANG_UNEXPECTED("Unrecognized item in DifferentiableTypeDictionary");
- UNREACHABLE_RETURN(LoweredValInfo());
- }
- }
-
- if (auto diffTypeDict = getBuilder()->findOrEmitDifferentiableTypeDictionary())
- {
- // Place the dictionary at the end of modules and generic blocks.
- diffTypeDict->moveToEnd();
- }
-
- return LoweredValInfo();
- }
-
#define IGNORED_CASE(NAME) \
LoweredValInfo visit##NAME(NAME*) { return LoweredValInfo(); }
@@ -5916,7 +5875,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IGNORED_CASE(SyntaxDecl)
IGNORED_CASE(AttributeDecl)
IGNORED_CASE(NamespaceDecl)
- IGNORED_CASE(DifferentiableTypeDictionaryItem)
#undef IGNORED_CASE
@@ -7119,6 +7077,27 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->addDecoration(inst, kIROp_DerivativeMemberDecoration, key);
}
+ void lowerDifferentiableAttribute(IRGenContext* subContext, IRInst* inst, DifferentiableAttribute* attr)
+ {
+ auto irDict = getBuilder()->addDifferentiableTypeDictionaryDecoration(inst);
+ for (auto& entry : attr->m_mapTypeToIDifferentiableWitness)
+ {
+ // Lower type and witness.
+ IRType* irType = lowerType(subContext, entry.Value->sub);
+ IRInst* irWitness = lowerVal(subContext, entry.Value).val;
+
+ SLANG_ASSERT(irType);
+
+ // If the witness can be lowered, and the differentiable type entry exists,
+ // add an entry to the context.
+ //
+ if (irWitness)
+ {
+ getBuilder()->addDifferentiableTypeEntry(irDict, irType, irWitness);
+ }
+ }
+ }
+
LoweredValInfo lowerMemberVarDecl(VarDecl* fieldDecl)
{
// Each field declaration in the AST translates into
@@ -7170,14 +7149,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// output for the chosen target.
addTargetIntrinsicDecorations(irFieldKey, fieldDecl);
-
return LoweredValInfo::simple(irFieldKey);
}
-
-
-
-
bool isImportedDecl(Decl* decl)
{
return Slang::isImportedDecl(context, decl);
@@ -7196,6 +7170,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
GenericTypeConstraintDecl* constraintDecl,
IRType* supType)
{
+
auto subBuilder = subContext->irBuilder;
// There are two cases we care about here.
@@ -7311,21 +7286,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
}
- // We only need dictionaries to be lowered for decls with executable code (i.e. statements)
- // Do not lower type dictionaries for inhertiance decls or decls
- // that are declaring a type, since this can create a cyclic dependancy.
- //
- if (as<FunctionDeclBase>(leafDecl))
- {
- for (auto diffTypeDict : genericDecl->getMembersOfType<DifferentiableTypeDictionary>())
- {
- // We directly use lowerDecl() instead of ensureDecl() to emit to
- // the current generic block instead of the top-level module.
- //
- lowerDecl(subContext, diffTypeDict);
- }
- }
-
return irGeneric;
}
@@ -7479,10 +7439,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), genericParam);
}
-
- // Add a differentiable type dictionary if necessary.
- if (auto diffTypeDict = subBuilder->findDifferentiableTypeDictionary(parentGeneric->getFirstBlock()))
- markInstsToClone(valuesToClone, parentGeneric->getFirstBlock(), diffTypeDict);
}
if (valuesToClone.Count() == 0)
{
@@ -7838,6 +7794,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addForwardDifferentiableDecoration(irFunc);
}
+ if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
+ {
+ lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);
+ }
// Always force inline diff setter accessor to prevent downstream compiler from complaining
// fields are not fully initialized for the first `inout` parameter.
diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h
index 937ecc95f..0412ef4da 100644
--- a/source/slang/slang-serialize-ast-type-info.h
+++ b/source/slang/slang-serialize-ast-type-info.h
@@ -83,6 +83,9 @@ struct SerialGetFieldType<DeclRef<T>>
template <typename T>
struct SerialTypeInfo<DeclRef<T>> : public SerialDeclRefBaseTypeInfo {};
+template<>
+struct SerialTypeInfo<DeclRefBase> : public SerialDeclRefBaseTypeInfo {};
+
// MatrixCoord can just go as is
template <>
struct SerialTypeInfo<MatrixCoord> : SerialIdentityTypeInfo<MatrixCoord> {};