summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang59
-rw-r--r--source/slang/diff.meta.slang58
-rw-r--r--source/slang/slang-ast-builder.cpp5
-rw-r--r--source/slang/slang-ast-builder.h8
-rw-r--r--source/slang/slang-ast-modifier.h17
-rw-r--r--source/slang/slang-ast-support-types.cpp6
-rw-r--r--source/slang/slang-ast-support-types.h1
-rw-r--r--source/slang/slang-check-decl.cpp137
-rw-r--r--source/slang/slang-check-expr.cpp2
-rw-r--r--source/slang/slang-check-impl.h16
-rw-r--r--source/slang/slang-check-modifier.cpp21
-rw-r--r--source/slang/slang-diagnostic-defs.h3
-rw-r--r--source/slang/slang-emit.cpp3
-rw-r--r--source/slang/slang-ir-cleanup-void.cpp49
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp835
-rw-r--r--source/slang/slang-ir-inst-pass-base.h24
-rw-r--r--source/slang/slang-ir-insts.h8
-rw-r--r--source/slang/slang-ir.cpp50
-rw-r--r--source/slang/slang-ir.h2
-rw-r--r--source/slang/slang-parser.cpp12
-rw-r--r--source/slang/slang-syntax.cpp2
21 files changed, 924 insertions, 394 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index ce52dbb56..05963bd11 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -105,6 +105,55 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {}
interface __BuiltinIntegerType : __BuiltinArithmeticType
{}
+
+/// Interface to denote types as differentiable.
+/// Allows for user-specified differential types as
+/// well as automatic generation, for when the associated type
+/// hasn't been declared explicitly.
+/// Note that the requirements must currently be defined in this exact order
+/// since the auto-diff pass relies on the order to grab the struct keys.
+///
+__magic_type(DifferentiableType)
+interface IDifferentiable
+{
+ // Note: the compiler implementation requires the `Differential` associated type to be defined
+ // before anything else.
+
+ __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) )
+ associatedtype Differential : IDifferentiable;
+
+ __builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) )
+ static Differential dzero();
+
+ __builtin_requirement($( (int)BuiltinRequirementKind::DAddFunc) )
+ static Differential dadd(Differential, Differential);
+
+ __builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) )
+ static Differential dmul(This, Differential);
+};
+
+__magic_type(DifferentialBottomType)
+__intrinsic_type($(kIROp_DifferentialBottomType))
+struct __DifferentialBottom : IDifferentiable
+{
+ typedef __DifferentialBottom Differential;
+
+ __intrinsic_op($(kIROp_DifferentialBottomValue))
+ static __DifferentialBottom dzero();
+
+ [__unsafeForceInlineEarly]
+ static __DifferentialBottom dadd(Differential a, Differential b)
+ {
+ return dzero();
+ }
+
+ [__unsafeForceInlineEarly]
+ static __DifferentialBottom dmul(This a, Differential b)
+ {
+ return dzero();
+ }
+}
+
/// A type that can represent non-integers
[sealed]
[builtin]
@@ -2739,16 +2788,6 @@ attribute_syntax [Specialize] : SpecializeAttribute;
__attributeTarget(DeclBase)
attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
-enum _BuiltinRequirementKind
-{
- DifferentialType = $( (int) BuiltinRequirementKind::DifferentialType),
- DZeroFunc = $( (int) BuiltinRequirementKind::DZeroFunc),
- DAddFunc = $( (int) BuiltinRequirementKind::DAddFunc),
- DMulFunc = $( (int) BuiltinRequirementKind::DMulFunc),
-};
-__attributeTarget(DeclBase)
-attribute_syntax [__BuiltinRequirement(kind: _BuiltinRequirementKind)] : BuiltinRequirementAttribute;
-
__attributeTarget(DeclBase)
attribute_syntax [builtin] : BuiltinAttribute;
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 1c3066e1d..ae4db603e 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,32 +9,6 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
-/// Interface to denote types as differentiable.
-/// Allows for user-specified differential types as
-/// well as automatic generation, for when the associated type
-/// hasn't been declared explicitly.
-/// Note that the requirements must currently be defined in this exact order
-/// since the auto-diff pass relies on the order to grab the struct keys.
-///
-__magic_type(DifferentiableType)
-interface IDifferentiable
-{
- // Note: the compiler implementation requires the `Differential` associated type to be defined
- // before anything else.
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DifferentialType)]
- associatedtype Differential;
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DZeroFunc)]
- static Differential dzero();
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DAddFunc)]
- static Differential dadd(Differential, Differential);
-
- [__BuiltinRequirement(_BuiltinRequirementKind.DMulFunc)]
- static Differential dmul(This, Differential);
-};
-
// Add extensions for the standard types
extension float : IDifferentiable
{
@@ -83,28 +57,6 @@ extension vector<float, N> : IDifferentiable
}
}
-__magic_type(DifferentialBottomType)
-__intrinsic_type($(kIROp_DifferentialBottomType))
-struct __DifferentialBottom : IDifferentiable
-{
- typedef __DifferentialBottom Differential;
-
- __intrinsic_op($(kIROp_DifferentialBottomValue))
- static __DifferentialBottom dzero();
-
- [__unsafeForceInlineEarly]
- static __DifferentialBottom dadd(Differential a, Differential b)
- {
- return dzero();
- }
-
- [__unsafeForceInlineEarly]
- static __DifferentialBottom dmul(This a, Differential b)
- {
- return dzero();
- }
-}
-
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
__generic<T : IDifferentiable>
@@ -121,6 +73,7 @@ struct DifferentialPair : IDifferentiable
__intrinsic_op($(kIROp_DifferentialPairGetDifferential))
T.Differential d();
+ [__unsafeForceInlineEarly]
T.Differential getDifferential()
{
return d();
@@ -129,6 +82,7 @@ struct DifferentialPair : IDifferentiable
__intrinsic_op($(kIROp_DifferentialPairGetPrimal))
T p();
+ [__unsafeForceInlineEarly]
T getPrimal()
{
return p();
@@ -137,7 +91,7 @@ struct DifferentialPair : IDifferentiable
[__unsafeForceInlineEarly]
static Differential dzero()
{
- return Differential(T.dzero(), Differential.DifferentialElementType.dzero());
+ return Differential(T.dzero(), T.Differential.dzero());
}
[__unsafeForceInlineEarly]
@@ -148,15 +102,15 @@ struct DifferentialPair : IDifferentiable
a.p(),
b.p()
),
- Differential.DifferentialElementType.dzero());
+ T.Differential.dadd(a.d(), b.d()));
}
[__unsafeForceInlineEarly]
static Differential dmul(This a, Differential b)
{
return Differential(
- T.dmul(a.p(), b.p()),
- Differential.DifferentialElementType.dzero());
+ T.dmul(a.p(), b.p()),
+ T.Differential.dmul(a.d(), b.d()));
}
};
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index beee16f9c..6249d7825 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -171,6 +171,11 @@ void SharedASTBuilder::registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modi
m_builtinTypes[Index(modifier->tag)] = type;
}
+void SharedASTBuilder::registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier)
+{
+ m_builtinRequirementDecls[modifier->kind] = decl;
+}
+
void SharedASTBuilder::registerMagicDecl(Decl* decl, MagicTypeModifier* modifier)
{
// In some cases the modifier will have been applied to the
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index 235bebfaa..190e3727d 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -19,6 +19,7 @@ class SharedASTBuilder : public RefObject
public:
void registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modifier);
+ void registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier);
void registerMagicDecl(Decl* decl, MagicTypeModifier* modifier);
/// Get the string type
@@ -49,6 +50,11 @@ public:
Decl* tryFindMagicDecl(String const& name);
+ Decl* findBuiltinRequirementDecl(BuiltinRequirementKind kind)
+ {
+ return m_builtinRequirementDecls[kind].GetValue();
+ }
+
/// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the Session.
NamePool* getNamePool() { return m_namePool; }
@@ -85,6 +91,7 @@ protected:
Type* m_builtinTypes[Index(BaseType::CountOf)];
Dictionary<String, Decl*> m_magicDecls;
+ Dictionary<BuiltinRequirementKind, Decl*> m_builtinRequirementDecls;
Dictionary<UnownedStringSlice, const ReflectClassInfo*> m_sliceToTypeMap;
Dictionary<Name*, const ReflectClassInfo*> m_nameToTypeMap;
@@ -334,6 +341,7 @@ public:
Witness* primalIsDifferentialWitness);
DeclRef<InterfaceDecl> getDifferentiableInterface();
+ Decl* getDifferentiableAssociatedTypeRequirement();
bool isDifferentiableInterfaceAvailable();
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 67ff297dc..57dfbac9e 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -395,6 +395,15 @@ class MagicTypeModifier : public Modifier
uint32_t tag = uint32_t(0);
};
+// A modifier that indicates a built-in associated type requirement (e.g., `Differential`)
+class BuiltinRequirementModifier : public Modifier
+{
+ SLANG_AST_CLASS(BuiltinRequirementModifier);
+
+ BuiltinRequirementKind kind;
+};
+
+
// A modifier applied to declarations of builtin types to indicate how they
// should be lowered to the IR.
//
@@ -590,14 +599,6 @@ class Attribute : public AttributeBase
AttributeArgumentValueDict intArgVals;
};
-// A modifier that indicates a built-in associated type requirement (e.g., `Differential`)
-class BuiltinRequirementAttribute : public Attribute
-{
- SLANG_AST_CLASS(BuiltinRequirementAttribute);
-
- BuiltinRequirementKind kind;
-};
-
class UserDefinedAttribute : public Attribute
{
SLANG_AST_CLASS(UserDefinedAttribute)
diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp
index badb524bb..7133f2a65 100644
--- a/source/slang/slang-ast-support-types.cpp
+++ b/source/slang/slang-ast-support-types.cpp
@@ -2,7 +2,9 @@
#include "slang-ast-base.h"
#include "slang-ast-type.h"
-Slang::QualType::QualType(Type* type)
+namespace Slang
+{
+QualType::QualType(Type* type)
: type(type)
, isLeftValue(false)
{
@@ -11,3 +13,5 @@ Slang::QualType::QualType(Type* type)
isLeftValue = true;
}
}
+
+}
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index 015e6969c..d4a781846 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -1510,6 +1510,7 @@ namespace Slang
DAddFunc, ///< The `IDifferentiable.dadd` function requirement
DMulFunc, ///< The `IDifferentiable.dmul` function requirement
};
+
} // namespace Slang
#endif
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 7140d541a..333e9d973 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -254,6 +254,8 @@ namespace Slang
void visitFunctionDeclBase(FunctionDeclBase* funcDecl);
void visitParamDecl(ParamDecl* paramDecl);
+
+ void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context);
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -1433,6 +1435,22 @@ namespace Slang
synth.pushScopeForContainer(aggTypeDecl);
}
+ // If `This` is nested inside a generic, we need to form a complete declref type to the
+ // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution
+ // from requirementDeclRef to get the generic substitution for outer generic parameters, and
+ // apply it to the newly synthesized decl.
+ SubstitutionSet substSet;
+ if (auto thisTypeSusbt = findThisTypeSubstitution(
+ requirementDeclRef.substitutions,
+ as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
+ {
+ if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
+ {
+ substSet = declRefType->declRef.substitutions;
+ }
+ }
+ auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet);
+
// Helper function to add a `diffType` field into the synthesized type for the original
// `member`.
auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl));
@@ -1462,6 +1480,22 @@ namespace Slang
addModifier(member, derivativeMemberModifier);
};
+ // Make the Differential type itself conform to `IDifferential` interface.
+ auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>();
+ inheritanceIDiffernetiable->base.type =
+ DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface());
+ inheritanceIDiffernetiable->parentDecl = aggTypeDecl;
+ aggTypeDecl->members.add(inheritanceIDiffernetiable);
+
+ // The `Differential` type of a `Differential` type is always itself.
+ auto assocTypeDef = m_astBuilder->create<TypeDefDecl>();
+ assocTypeDef->nameAndLoc.name = getName("Differential");
+ assocTypeDef->type.type = satisfyingType;
+ assocTypeDef->parentDecl = aggTypeDecl;
+ assocTypeDef->setCheckState(DeclCheckState::Checked);
+ aggTypeDecl->members.add(assocTypeDef);
+
+ // Go through all members and collect their differential types.
// Go through super types.
for (auto inheritance : context->parentDecl->getMembersOfType<InheritanceDecl>())
{
@@ -1476,8 +1510,7 @@ namespace Slang
}
}
}
-
- // We go through all members and generate their differential counterparts.
+ // Go through all var members.
for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>())
{
auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type);
@@ -1488,22 +1521,9 @@ namespace Slang
addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>());
- // If `This` is nested inside a generic, we need to form a complete declref type to the
- // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution
- // from requirementDeclRef to get the generic substitution for outer generic parameters, and
- // apply it to the newly synthesized decl.
- SubstitutionSet substSet;
- if (auto thisTypeSusbt = findThisTypeSubstitution(
- requirementDeclRef.substitutions,
- as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl)))
- {
- if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub))
- {
- substSet = declRefType->declRef.substitutions;
- }
- }
-
- auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet);
+ // Synthesize the rest of IDifferential method conformances by recursively checking
+ // conformance on the synthesized decl.
+ checkAggTypeConformance(aggTypeDecl);
if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable))
{
@@ -1616,6 +1636,50 @@ namespace Slang
}
};
+ // Check that types used as `Differential` type use themselves as their own `Differential` type.
+ struct SemanticsDeclDifferentialConformanceVisitor
+ : public SemanticsDeclVisitorBase
+ , public DeclVisitor<SemanticsDeclDifferentialConformanceVisitor>
+ {
+ SemanticsDeclDifferentialConformanceVisitor(SemanticsContext const& outer)
+ : SemanticsDeclVisitorBase(outer)
+ {}
+ void visitDecl(Decl*) {}
+ void visitDeclGroup(DeclGroup*) {}
+
+ void visitInheritanceDecl(InheritanceDecl* inheritanceDecl)
+ {
+ if (as<InterfaceDecl>(inheritanceDecl->parentDecl))
+ return;
+
+ if (!inheritanceDecl->witnessTable)
+ return;
+ auto baseType = as<DeclRefType>(inheritanceDecl->witnessTable->baseType);
+ if (!baseType)
+ return;
+ if (baseType->declRef.getDecl() != m_astBuilder->getDifferentiableInterface().getDecl())
+ return;
+ RequirementWitness witnessValue;
+ auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType);
+ if (!inheritanceDecl->witnessTable->requirementDictionary.TryGetValue(requirementDecl, witnessValue))
+ return;
+
+ // A type used as differential type must have itself as its own differential type.
+ if (witnessValue.getFlavor() != RequirementWitness::Flavor::val)
+ return;
+ auto differentialType = as<DeclRefType>(witnessValue.getVal());
+ if (!differentialType)
+ return;
+ auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType);
+ if (!differentialType->equals(diffDiffType))
+ {
+ SourceLoc sourceLoc = differentialType->declRef.getDecl()->loc;
+ getSink()->diagnose(sourceLoc, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType);
+ getSink()->diagnose(inheritanceDecl, Diagnostics::noteSeeUseOfDifferentialType, differentialType, inheritanceDecl->getSup());
+ }
+ }
+ };
+
/// Recursively register any builtin declarations that need to be attached to the `session`.
///
/// This function should only be needed for declarations in the standard library.
@@ -1632,7 +1696,10 @@ namespace Slang
{
sharedASTBuilder->registerMagicDecl(decl, magicMod);
}
-
+ if (auto builtinRequirement = decl->findModifier<BuiltinRequirementModifier>())
+ {
+ sharedASTBuilder->registerBuiltinRequirementDecl(decl, builtinRequirement);
+ }
if(auto containerDecl = as<ContainerDecl>(decl))
{
for(auto childDecl : containerDecl->members)
@@ -2217,13 +2284,14 @@ namespace Slang
// associated type and see if they can be satisfied.
//
bool conformance = true;
+ Val* witness = nullptr;
for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(requiredAssociatedTypeDeclRef))
{
// Grab the type we expect to conform to from the constraint.
auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef);
// Perform a search for a witness to the subtype relationship.
- auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType);
+ witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType);
if (witness)
{
// If a subtype witness was found, then the conformance
@@ -3040,7 +3108,7 @@ namespace Slang
witnessTable))
return true;
- if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>())
+ if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>())
{
switch (builtinAttr->kind)
{
@@ -3067,7 +3135,7 @@ namespace Slang
if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>())
{
- if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>())
+ if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>())
{
switch (builtinAttr->kind)
{
@@ -3160,7 +3228,7 @@ namespace Slang
bool hasDifferentialAssocType = false;
for (auto existingEntry : witnessTable->requirementList)
{
- if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>())
+ if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementModifier>())
{
if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType &&
existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none)
@@ -3401,7 +3469,7 @@ namespace Slang
// requirement, it may be possible that we can still synthesis the
// implementation if this is one of the known builtin requirements.
// Otherwise, report diagnostic now.
- if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementAttribute>())
+ if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>())
{
getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef);
getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef);
@@ -4499,11 +4567,29 @@ namespace Slang
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
}
+ void SemanticsDeclBodyVisitor::_maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context)
+ {
+ auto parentDifferentiableAttr = context.getParentDifferentiableAttribute();
+ if (parentDifferentiableAttr)
+ {
+ auto diffBottomType = m_astBuilder->getDifferentialBottomType();
+ auto idifferentiable = DeclRef<InterfaceDecl>(m_astBuilder->getDifferentiableInterface(), nullptr);
+ auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(diffBottomType, idifferentiable));
+ SLANG_ASSERT(witness);
+ parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.Add(
+ as<DeclRefType>(diffBottomType)->declRef,
+ witness);
+ }
+ }
+
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
+ auto newContext = withParentFunc(decl);
+ _maybeRegisterDifferentialBottomTypeConformance(newContext);
+
if (auto body = decl->body)
{
- checkBodyStmt(body, decl);
+ checkStmt(decl->body, newContext);
}
}
@@ -6234,6 +6320,7 @@ namespace Slang
case DeclCheckState::TypesFullyResolved:
SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl);
+ SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl);
break;
case DeclCheckState::Checked:
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index ad199300a..09dd9eea1 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -428,7 +428,7 @@ namespace Slang
// We will only ever need to synthesis a type to satisfy an associatedtype requirement.
// In this case the lookup should have resolved to a known associatedtype decl.
- auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinRequirementAttribute>();
+ auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinRequirementModifier>();
if (!builtinAssocTypeAttr)
return nullptr;
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index a0141911a..76918ebbe 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -214,6 +214,7 @@ namespace Slang
Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
};
+
/// Shared state for a semantics-checking session.
struct SharedSemanticsContext
{
@@ -274,7 +275,6 @@ namespace Slang
return m_linkage->isInLanguageServer();
return false;
}
-
/// Get the list of extension declarations that appear to apply to `decl` in this context
List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl);
@@ -375,6 +375,11 @@ namespace Slang
return result;
}
+ DifferentiableAttribute* getParentDifferentiableAttribute()
+ {
+ return m_parentDifferentiableAttr;
+ }
+
/// A scope that is local to a particular expression, and
/// that can be used to allocate temporary bindings that
/// might be needed by that expression or its sub-expressions.
@@ -1041,6 +1046,15 @@ namespace Slang
DeclRef<AssocTypeDecl> requirementDeclRef,
RefPtr<WitnessTable> witnessTable);
+ struct DifferentiableMemberInfo
+ {
+ Decl* memberDecl;
+ Type* diffType;
+ };
+
+ /// Gather differentiable members from decl.
+ List<DifferentiableMemberInfo> collectDifferentiableMemberInfo(ContainerDecl* decl);
+
// Find the appropriate member of a declared type to
// satisfy a requirement of an interface the type
// claims to conform to.
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 91f655a15..d8b05198c 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -484,27 +484,6 @@ namespace Slang
return false;
}
}
- else if (auto builtinAssocTypeAttr = as<BuiltinRequirementAttribute>(attr))
- {
- if (attr->args.getCount() == 1)
- {
- //IntVal* outIntVal;
- if (auto cInt = checkConstantEnumVal(attr->args[0]))
- {
- builtinAssocTypeAttr->kind = (BuiltinRequirementKind)(cInt->value);
- }
- else
- {
- getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName);
- return false;
- }
- }
- else
- {
- getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName);
- return false;
- }
- }
else if (auto unrollAttr = as<UnrollAttribute>(attr))
{
// Check has an argument. We need this because default behavior is to give an error
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 9e939e476..ffee0622c 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -303,6 +303,9 @@ DIAGNOSTIC(30093, Error, uncaughtTryCallInNonThrowFunc, "the current function or
DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw an error, and therefore must be called within a 'try' clause")
DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.")
+DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "type '$0' is used as a `Differential` type, therefore it must serve as its own `Differential` type.")
+DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.")
+
// Attributes
DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'")
DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)")
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index fcdee78ea..9c72f1d63 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -381,9 +381,6 @@ Result linkAndOptimizeIR(
// 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass)
// processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
-
- // 3. Fill in higher-order invocations with the generated functions.
- processDerivativeCalls(irModule);
stripAutoDiffDecorations(irModule);
diff --git a/source/slang/slang-ir-cleanup-void.cpp b/source/slang/slang-ir-cleanup-void.cpp
index ac520c1d5..a72157a69 100644
--- a/source/slang/slang-ir-cleanup-void.cpp
+++ b/source/slang/slang-ir-cleanup-void.cpp
@@ -36,26 +36,26 @@ namespace Slang
switch (inst->getOp())
{
case kIROp_Call:
+ case kIROp_makeStruct:
{
// Remove void argument.
- auto call = as<IRCall>(inst);
List<IRInst*> newArgs;
- for (UInt i = 0; i < call->getArgCount(); i++)
+ for (UInt i = 0; i < inst->getOperandCount(); i++)
{
- auto arg = call->getArg(i);
+ auto arg = inst->getOperand(i);
if (arg->getDataType() && arg->getDataType()->getOp() == kIROp_VoidType)
{
continue;
}
newArgs.add(arg);
}
- if (newArgs.getCount() != (Index)call->getArgCount())
+ if (newArgs.getCount() != (Index)inst->getOperandCount())
{
IRBuilder builder(&sharedBuilderStorage);
- builder.setInsertBefore(call);
- auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), newArgs);
- call->replaceUsesWith(newCall);
- call->removeAndDeallocate();
+ builder.setInsertBefore(inst);
+ auto newCall = builder.emitIntrinsicInst(inst->getFullType(), inst->getOp(), newArgs.getCount(), newArgs.getBuffer());
+ inst->replaceUsesWith(newCall);
+ inst->removeAndDeallocate();
inst = newCall;
}
}
@@ -111,16 +111,43 @@ namespace Slang
break;
case kIROp_StructType:
{
- // TODO: cleanup void fields.
+ List<IRInst*> toRemove;
+ for (auto child : inst->getChildren())
+ {
+ if (auto field = as<IRStructField>(child))
+ {
+ if (field->getFieldType()->getOp() == kIROp_VoidType)
+ {
+ toRemove.add(field);
+ }
+ }
+ }
+ for (auto ii : toRemove)
+ ii->removeAndDeallocate();
}
break;
default:
break;
}
- // TODO: If inst has void type, all uses of it should be replaced with void val.
+ // If inst has void type, all uses of it should be replaced with void val.
// We should do this only for a subset of opcodes known to be safe.
-
+ switch(inst->getOp())
+ {
+ case kIROp_Load:
+ case kIROp_getElement:
+ case kIROp_GetOptionalValue:
+ case kIROp_FieldExtract:
+ case kIROp_GetTupleElement:
+ case kIROp_GetResultError:
+ case kIROp_GetResultValue:
+ if (inst->getDataType()->getOp() == kIROp_VoidType)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(inst);
+ inst->replaceUsesWith(builder.getVoidValue());
+ }
+ }
}
void processModule()
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index d0bf8f347..8a4fe23d0 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -7,6 +7,7 @@
#include "slang-ir-dce.h"
#include "slang-ir-eliminate-phis.h"
#include "slang-ir-util.h"
+#include "slang-ir-inst-pass-base.h"
// origX, primalX, diffX
// origX -> primalX (cloneEnv)
@@ -20,9 +21,19 @@ struct Pair
{
P primal;
D differential;
-
+ Pair() = default;
Pair(P primal, D differential) : primal(primal), differential(differential)
{}
+ HashCode getHashCode() const
+ {
+ Hasher hasher;
+ hasher << primal << differential;
+ return hasher.getResult();
+ }
+ bool operator ==(const Pair& other) const
+ {
+ return primal == other.primal && differential == other.differential;
+ }
};
typedef Pair<IRInst*, IRInst*> InstPair;
@@ -43,6 +54,11 @@ struct AutoDiffSharedContext
//
IRStructKey* differentialAssocTypeStructKey = nullptr;
+ // The struct key for the witness that `Differential` associated type conforms to
+ // `IDifferential`.
+ IRStructKey* differentialAssocTypeWitnessStructKey = nullptr;
+
+
// The struct key for the 'zero()' associated type
// defined inside IDifferential. We use this to lookup the
// implementation of zero() for a given type.
@@ -54,6 +70,9 @@ struct AutoDiffSharedContext
// implementation of add() for a given type.
//
IRStructKey* addMethodStructKey = nullptr;
+
+ IRStructKey* mulMethodStructKey = nullptr;
+
// Modules that don't use differentiable types
// won't have the IDifferentiable interface type available.
@@ -69,8 +88,10 @@ struct AutoDiffSharedContext
if (differentiableInterfaceType)
{
differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+ differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey();
zeroMethodStructKey = findZeroMethodStructKey();
addMethodStructKey = findAddMethodStructKey();
+ mulMethodStructKey = findMulMethodStructKey();
if (differentialAssocTypeStructKey)
isInterfaceAvailable = true;
@@ -103,22 +124,32 @@ struct AutoDiffSharedContext
return getIDifferentiableStructKeyAtIndex(0);
}
- IRStructKey* findZeroMethodStructKey()
+ IRStructKey* findDifferentialTypeWitnessStructKey()
{
return getIDifferentiableStructKeyAtIndex(1);
}
- IRStructKey* findAddMethodStructKey()
+ IRStructKey* findZeroMethodStructKey()
{
return getIDifferentiableStructKeyAtIndex(2);
}
+ IRStructKey* findAddMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(3);
+ }
+
+ IRStructKey* findMulMethodStructKey()
+ {
+ return getIDifferentiableStructKeyAtIndex(4);
+ }
+
IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index)
{
if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType)
{
- // Assume for now that IDifferentiable has exactly four fields.
- SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4);
+ // Assume for now that IDifferentiable has exactly five fields.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5);
if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index)))
return as<IRStructKey>(entry->getRequirementKey());
else
@@ -300,7 +331,16 @@ struct DifferentialPairTypeBuilder
IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key)
{
- if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
+ auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType());
+ if (baseTypeInfo.isTrivial)
+ {
+ if (key == globalPrimalKey)
+ return baseInst;
+ else
+ return builder->getDifferentialBottom();
+ }
+
+ if (auto basePairStructType = as<IRStructType>(baseTypeInfo.loweredType))
{
return as<IRFieldExtract>(builder->emitFieldExtract(
findField(basePairStructType, key)->getFieldType(),
@@ -308,7 +348,7 @@ struct DifferentialPairTypeBuilder
key
));
}
- else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ else if (auto ptrType = as<IRPtrTypeBase>(baseTypeInfo.loweredType))
{
if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType()))
{
@@ -334,7 +374,7 @@ struct DifferentialPairTypeBuilder
key));
}
}
- else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType()))
+ else if (auto specializedType = as<IRSpecialize>(baseTypeInfo.loweredType))
{
// TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's
// type, emit the specialization type.
@@ -420,25 +460,64 @@ struct DifferentialPairTypeBuilder
{
SLANG_ASSERT(!as<IRParam>(origBaseType));
SLANG_ASSERT(diffType);
- auto pairStructType = builder->createStructType();
- builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
- builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType);
+ if (diffType->getOp() != kIROp_DifferentialBottomType)
+ {
+ auto pairStructType = builder->createStructType();
+ builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType);
+ builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType);
+ return pairStructType;
+ }
+ return origBaseType;
+ }
- return pairStructType;
+ struct LoweredPairTypeInfo
+ {
+ IRInst* loweredType;
+ bool isTrivial;
+ };
+
+ IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
+ {
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
}
- IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType)
+ IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
{
- if (pairTypeCache.ContainsKey(origBaseType))
- return pairTypeCache[origBaseType];
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+ }
- auto pairType = _createDiffPairType(builder, origBaseType, diffType);
- pairTypeCache.Add(origBaseType, pairType);
+ LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType)
+ {
+ LoweredPairTypeInfo result = {};
+
+ if (pairTypeCache.TryGetValue(originalPairType, result))
+ return result;
+ auto pairType = as<IRDifferentialPairType>(originalPairType);
+ if (!pairType)
+ {
+ result.isTrivial = true;
+ result.loweredType = originalPairType;
+ return result;
+ }
+ auto primalType = pairType->getValueType();
+ if (as<IRParam>(primalType))
+ {
+ result.isTrivial = false;
+ result.loweredType = nullptr;
+ return result;
+ }
+
+ auto diffType = getDiffTypeFromPairType(builder, pairType);
+ result.loweredType = _createDiffPairType(builder, pairType->getValueType(), (IRType*)diffType);
+ result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType);
+ pairTypeCache.Add(originalPairType, result);
- return pairType;
+ return result;
}
- Dictionary<IRInst*, IRInst*> pairTypeCache;
+ Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache;
IRStructKey* globalPrimalKey = nullptr;
@@ -447,6 +526,8 @@ struct DifferentialPairTypeBuilder
IRInst* genericDiffPairType = nullptr;
List<IRInst*> generatedTypeList;
+
+ AutoDiffSharedContext* sharedContext = nullptr;
};
struct JVPTranscriber
@@ -474,8 +555,15 @@ struct JVPTranscriber
DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
- JVPTranscriber(AutoDiffSharedContext* shared)
- : differentiableTypeConformanceContext(shared)
+ List<InstPair> followUpFunctionsToTranscribe;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table that `DifferentialBottom:IDifferential`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary<InstPair, IRInst*> differentialPairTypes;
+
+ JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder)
+ : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder)
{}
DiagnosticSink* getSink()
@@ -592,8 +680,75 @@ struct JVPTranscriber
return builder->getFuncType(newParameterTypes, diffReturnType);
}
+ IRWitnessTable* getDifferentialBottomWitness()
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
+ auto result =
+ as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+ builder.getDifferentialBottomType()));
+ SLANG_ASSERT(result);
+ return result;
+ }
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
+ auto result =
+ as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+ builder.getDifferentialBottomType()));
+ if (result)
+ return result;
+
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ auto diffType = differentiateType(&builder, diffPairType->getValueType());
+ auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness());
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ return table;
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ auto witness = as<IRWitnessTable>(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+ if (!witness)
+ witness = getDifferentialBottomWitness();
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
IRType* differentiateType(IRBuilder* builder, IRType* origType)
{
+ IRInst* diffType = nullptr;
+ if (!instMapD.TryGetValue(origType, diffType))
+ {
+ diffType = _differentiateTypeImpl(builder, origType);
+ instMapD[origType] = diffType;
+ }
+ return (IRType*)diffType;
+ }
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+ {
if (auto ptrType = as<IRPtrTypeBase>(origType))
return builder->getPtrType(
origType->getOp(),
@@ -628,6 +783,14 @@ struct JVPTranscriber
else
return nullptr;
}
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
case kIROp_FuncType:
return differentiateFunctionType(builder, as<IRFuncType>(primalType));
@@ -660,7 +823,7 @@ struct JVPTranscriber
return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
}
}
-
+
IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
{
// If this is a PtrType (out, inout, etc..), then create diff pair from
@@ -675,7 +838,7 @@ struct JVPTranscriber
}
auto diffType = differentiateType(builder, primalType);
if (diffType)
- return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType);
+ return (IRType*)getOrCreateDiffPairType(primalType);
return nullptr;
}
@@ -692,7 +855,7 @@ struct JVPTranscriber
if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
{
- IRParam* diffPairParam = builder->emitParam(diffPairType);
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
auto diffPairVarName = makeDiffPairName(origParam);
if (diffPairVarName.getLength() > 0)
@@ -700,9 +863,20 @@ struct JVPTranscriber
SLANG_ASSERT(diffPairParam);
- return InstPair(
- pairBuilder->emitPrimalFieldAccess(builder, diffPairParam),
- pairBuilder->emitDiffFieldAccess(builder, diffPairParam));
+ if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
}
@@ -826,30 +1000,52 @@ struct JVPTranscriber
InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
auto origPtr = origLoad->getPtr();
-
- auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ auto primalPtr = lookupPrimalInst(origPtr, nullptr);
+ auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType();
- IRInst* diffLoad = nullptr;
+ if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType))
+ {
+ // Special case load from an `out` param, which will not have corresponding `diff` and
+ // `primal` insts yet.
+ auto load = builder->emitLoad(primalPtr);
+ auto primalElement = builder->emitDifferentialPairGetPrimal(load);
+ auto diffElement = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load);
+ return InstPair(primalElement, diffElement);
+ }
+ auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+ IRInst* diffLoad = nullptr;
if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
{
// Default case, we're loading from a known differential inst.
diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
- return InstPair(primalLoad, diffLoad);
- }
- return InstPair(primalLoad, nullptr);
+ }
+ return InstPair(primalLoad, diffLoad);
}
InstPair transcribeStore(IRBuilder* builder, IRStore* origStore)
{
IRInst* origStoreLocation = origStore->getPtr();
IRInst* origStoreVal = origStore->getVal();
-
- auto primalStore = cloneInst(&cloneEnv, builder, origStore);
-
+ auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr);
auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
+ auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr);
auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+ if (!diffStoreLocation)
+ {
+ auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType());
+ if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType()))
+ {
+ auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal);
+ auto store = builder->emitStore(primalStoreLocation, valToStore);
+ return InstPair(store, nullptr);
+ }
+ }
+
+ auto primalStore = cloneInst(&cloneEnv, builder, origStore);
+
IRInst* diffStore = nullptr;
// If the stored value has a differential version,
@@ -1052,8 +1248,9 @@ struct JVPTranscriber
if (diffReturnType->getOp() != kIROp_VoidType)
{
- IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst);
- IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst);
+ IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst);
+ auto diffType = differentiateType(builder, origCall->getFullType());
+ IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst);
return InstPair(primalResultValue, diffResultValue);
}
else
@@ -1174,14 +1371,16 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
- InstPair transcribeConst(IRBuilder*, IRInst* origInst)
+ InstPair transcribeConst(IRBuilder* builder, IRInst* origInst)
{
switch(origInst->getOp())
{
case kIROp_FloatLit:
+ return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f));
case kIROp_VoidLit:
+ return InstPair(origInst, origInst);
case kIROp_IntLit:
- return InstPair(origInst, nullptr);
+ return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0));
}
getSink()->diagnose(
@@ -1245,6 +1444,14 @@ struct JVPTranscriber
{
if (auto diffType = differentiateType(builder, primalType))
{
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ }
// Since primalType has a corresponding differential type, we can lookup the
// definition for zero().
auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
@@ -1458,40 +1665,63 @@ struct JVPTranscriber
return InstPair(diffLoop, diffLoop);
}
- // Transcribe a function definition.
- InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc)
+ InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst)
{
- IRFunc* primalFunc = nullptr;
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalVal);
+ auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffPrimalVal);
+ auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue());
+ SLANG_ASSERT(primalDiffVal);
+ auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue());
+ SLANG_ASSERT(diffDiffVal);
- differentiableTypeConformanceContext.setFunc(origFunc);
+ auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal);
+ auto diffPair = builder->emitMakeDifferentialPair(
+ differentiateType(builder, origInst->getDataType()),
+ primalDiffVal,
+ diffDiffVal);
+ return InstPair(primalPair, diffPair);
+ }
- auto oldLoc = builder->getInsertLoc();
+ InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst)
+ {
+ SLANG_ASSERT(
+ origInst->getOp() == kIROp_DifferentialPairGetDifferential ||
+ origInst->getOp() == kIROp_DifferentialPairGetPrimal);
- // If this is a top-level function, there is no need to clone it
- // since it is visible in all the scopes.
- // Otherwise, we need to clone it in case of generic scopes.
- //
- // TODO(sai): Is this the correct thing to do? Can a function cloned inside a
- // generic scope but is not the return value of that generic, be used within
- // that scope? Or do we have to call out to the original generic specialized with
- // the current generic params?
- //
- bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr);
- if (isTopLevelFunc)
- {
- builder->setInsertBefore(origFunc);
- primalFunc = origFunc;
- }
+ auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(primalVal);
+
+ auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0));
+ SLANG_ASSERT(diffVal);
+
+ auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal);
+
+ auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType());
+ IRInst* diffResultType = nullptr;
+ if (origInst->getOp() == kIROp_DifferentialPairGetDifferential)
+ diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType);
else
- {
- // TODO(sai): this might never be called, and it might never make sense
- // to call it either. Potentially remove this.
- primalFunc = as<IRFunc>(
- cloneInst(&cloneEnv, builder, origFunc));
- }
+ diffResultType = diffValPairType->getValueType();
+ auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal);
+ return InstPair(primalResult, diffResult);
+ }
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* builder, IRFunc* origFunc)
+ {
+ auto oldLoc = builder->getInsertLoc();
+
+ IRFunc* primalFunc = origFunc;
+
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
+ builder->setInsertBefore(origFunc);
+ primalFunc = origFunc;
auto diffFunc = builder->createFunc();
-
+
SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
IRType* diffFuncType = this->differentiateFunctionType(
builder,
@@ -1505,10 +1735,33 @@ struct JVPTranscriber
newNameSb << "s_jvp_" << originalName;
builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
}
-
+ builder->addForwardDerivativeDecoration(origFunc, diffFunc);
+
+ // Mark the generated derivative function itself as differentiable.
+ builder->addForwardDifferentiableDecoration(diffFunc);
+
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ cloneDecoration(dictDecor, diffFunc);
+ }
+
+ // Reset builder position
+ builder->setInsertLoc(oldLoc);
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ auto oldLoc = builder->getInsertLoc();
+
+ differentiableTypeConformanceContext.setFunc(primalFunc);
// Transcribe children from origFunc into diffFunc
builder->setInsertInto(diffFunc);
- for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock())
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
this->transcribe(builder, block);
// Reset builder position
@@ -1685,6 +1938,11 @@ struct JVPTranscriber
case kIROp_ifElse:
return transcribeIfElse(builder, as<IRIfElse>(origInst));
+ case kIROp_MakeDifferentialPair:
+ return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst));
+ case kIROp_DifferentialPairGetPrimal:
+ case kIROp_DifferentialPairGetDifferential:
+ return transcribeDifferentialPairGetElement(builder, origInst);
}
// If none of the cases have been hit, check if the instruction is a
@@ -1722,7 +1980,7 @@ struct JVPTranscriber
switch (origInst->getOp())
{
case kIROp_Func:
- return transcribeFunc(builder, as<IRFunc>(origInst));
+ return transcribeFuncHeader(builder, as<IRFunc>(origInst));
case kIROp_Block:
return transcribeBlock(builder, as<IRBlock>(origInst));
@@ -1741,45 +1999,7 @@ struct JVPTranscriber
}
};
-struct IRWorkQueue
-{
- // Work list to hold the active set of insts whose children
- // need to be looked at.
- //
- List<IRInst*> workList;
- HashSet<IRInst*> workListSet;
-
- void push(IRInst* inst)
- {
- if(!inst) return;
- if(workListSet.Contains(inst)) return;
-
- workList.add(inst);
- workListSet.Add(inst);
- }
-
- IRInst* pop()
- {
- if (workList.getCount() != 0)
- {
- IRInst* topItem = workList.getFirst();
- // TODO(Sai): Repeatedly calling removeAt() can be really slow.
- // Consider a specialized data structure or using removeLast()
- //
- workList.removeAt(0);
- workListSet.Remove(topItem);
- return topItem;
- }
- return nullptr;
- }
-
- IRInst* peek()
- {
- return workList.getFirst();
- }
-};
-
-struct JVPDerivativeContext
+struct JVPDerivativeContext : public InstPassBase
{
DiagnosticSink* getSink()
@@ -1795,6 +2015,7 @@ struct JVPDerivativeContext
//
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
+ sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
@@ -1809,8 +2030,12 @@ struct JVPDerivativeContext
// IRDifferentialPairGetPrimal with 'primal' field access, and
// IRMakeDifferentialPair with an IRMakeStruct.
//
+ modified |= simplifyDifferentialBottomType(builder);
+
modified |= processPairTypes(builder, module->getModuleInst());
-
+
+ modified |= eliminateDifferentialBottomType(builder);
+
return modified;
}
@@ -1826,121 +2051,92 @@ struct JVPDerivativeContext
//
bool processReferencedFunctions(IRBuilder* builder)
{
- IRWorkQueue* workQueue = &(workQueueStorage);
+ List<IRForwardDifferentiate*> autoDiffWorkList;
- // Put the top-level inst into the queue.
- workQueue->push(module->getModuleInst());
-
- // Keep processing items until the queue is complete.
- while (IRInst* workItem = workQueue->pop())
- {
- for(auto child = workItem->getFirstChild(); child; child = child->getNextInst())
+ for (;;)
+ {
+ // Collect all `ForwardDifferentiate` insts from the module.
+ autoDiffWorkList.clear();
+ processInstsOfType<IRForwardDifferentiate>(kIROp_ForwardDifferentiate, [&](IRForwardDifferentiate* fwdDiffInst)
{
- // Either the child instruction has more children (func/block etc..)
- // and we add it to the work list for further processing, or
- // it's an ordinary inst in which case we check if it's a ForwardDifferentiate
- // instruction.
- //
- if (child->getFirstChild() != nullptr)
- workQueue->push(child);
-
- if (auto jvpDiffInst = as<IRForwardDifferentiate>(child))
- {
- auto baseInst = jvpDiffInst->getBaseFn();
+ autoDiffWorkList.add(fwdDiffInst);
+ });
- IRGlobalValueWithCode* baseFunction = nullptr;
+ if (autoDiffWorkList.getCount() == 0)
+ break;
- if (auto specializeInst = as<IRSpecialize>(baseInst))
- {
- // Certain specialize insts come with a derivative
- // reference attached. Skip such instructions.
- //
- if (lookupJVPReference(specializeInst)) continue;
- }
- else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst))
+ // Process collected `ForwardDifferentiate` insts and replace them with placeholders for
+ // differentiated functions.
+ transcriberStorage.followUpFunctionsToTranscribe.clear();
+
+ for (auto fwdDiffInst : autoDiffWorkList)
+ {
+ auto baseInst = fwdDiffInst->getBaseFn();
+ if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst))
+ {
+ if (auto existingDiffFunc = lookupJVPReference(baseFunction))
{
- baseFunction = globalValWithCode;
+ fwdDiffInst->replaceUsesWith(existingDiffFunc);
+ fwdDiffInst->removeAndDeallocate();
}
-
- SLANG_ASSERT(baseFunction);
-
- // If the JVP Reference already exists, no need to
- // differentiate again.
- //
- if (lookupJVPReference(baseFunction)) continue;
-
- if (isMarkedForForwardDifferentiation(baseFunction))
+ else if (isMarkedForForwardDifferentiation(baseFunction))
{
if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
{
- IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction);
+ IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction);
SLANG_ASSERT(diffFunc);
- builder->addForwardDerivativeDecoration(baseFunction, diffFunc);
- workQueue->push(diffFunc);
- }
+ fwdDiffInst->replaceUsesWith(diffFunc);
+ fwdDiffInst->removeAndDeallocate();
+ }
else
{
// TODO(Sai): This would probably be better with a more specific
// error code.
- getSink()->diagnose(jvpDiffInst->sourceLoc,
+ getSink()->diagnose(fwdDiffInst->sourceLoc,
Diagnostics::internalCompilerError,
"Unexpected instruction. Expected func or generic");
}
}
- else
+ else
{
// TODO(Sai): This would probably be better with a more specific
// error code.
- getSink()->diagnose(jvpDiffInst->sourceLoc,
+ getSink()->diagnose(fwdDiffInst->sourceLoc,
Diagnostics::internalCompilerError,
"Cannot differentiate functions not marked for differentiation");
}
}
}
- }
-
- return true;
- }
-
- IRInst* lowerPairType(IRBuilder* builder, IRType* type)
- {
-
- if (auto pairType = as<IRDifferentialPairType>(type))
- {
- builder->setInsertBefore(pairType);
-
- if (!as<IRType>(pairType->getValueType()))
+ // Actually synthesize the derivatives.
+ List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
{
- return nullptr;
- }
- auto witness = pairType->getWitness();
- auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey);
- if (!diffType)
- {
- return nullptr;
+ auto diffFunc = as<IRFunc>(task.differential);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as<IRFunc>(task.primal);
+ SLANG_ASSERT(primalFunc);
+
+ transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
}
- auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
- builder,
- pairType->getValueType(),
- (IRType*)(diffType));
- pairType->replaceUsesWith(diffPairStructType);
- pairType->removeAndDeallocate();
+ // Transcribing the function body really shouldn't produce more follow up function body work.
+ // However it may produce new `ForwardDifferentiate` instructions, which we collect and process
+ // in the next iteration.
+ SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0);
- return diffPairStructType;
- }
- else if (auto loweredStructType = as<IRStructType>(type))
- {
- // Already lowered to struct.
- return loweredStructType;
- }
- else if (auto specializedStructType = as<IRSpecialize>(type))
- {
- // Already lowered to specialized struct.
- return specializedStructType;
}
-
- return nullptr;
+ return true;
+ }
+
+ IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr)
+ {
+ builder->setInsertBefore(pairType);
+ auto loweredPairTypeInfo = (&pairBuilderStorage)->lowerDiffPairType(
+ builder,
+ pairType);
+ if (isTrivial)
+ *isTrivial = loweredPairTypeInfo.isTrivial;
+ return loweredPairTypeInfo.loweredType;
}
IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
@@ -1948,19 +2144,24 @@ struct JVPDerivativeContext
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
- if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType()))
+ bool isTrivial = false;
+ auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType());
+ if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial))
{
builder->setInsertBefore(makePairInst);
-
- List<IRInst*> operands;
- operands.add(makePairInst->getPrimalValue());
- operands.add(makePairInst->getDifferentialValue());
-
- auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
- makePairInst->replaceUsesWith(makeStructInst);
+ IRInst* result = nullptr;
+ if (isTrivial)
+ {
+ result = makePairInst->getPrimalValue();
+ }
+ else
+ {
+ IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() };
+ result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands);
+ }
+ makePairInst->replaceUsesWith(result);
makePairInst->removeAndDeallocate();
-
- return makeStructInst;
+ return result;
}
}
@@ -1971,11 +2172,11 @@ struct JVPDerivativeContext
{
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
{
- if (lowerPairType(builder, getDiffInst->getBase()->getDataType()))
+ if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), nullptr))
{
builder->setInsertBefore(getDiffInst);
-
- auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ IRInst* diffFieldExtract = nullptr;
+ diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
getDiffInst->replaceUsesWith(diffFieldExtract);
getDiffInst->removeAndDeallocate();
return diffFieldExtract;
@@ -1983,14 +2184,14 @@ struct JVPDerivativeContext
}
else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
{
- if (lowerPairType(builder, getPrimalInst->getBase()->getDataType()))
+ if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), nullptr))
{
builder->setInsertBefore(getPrimalInst);
- auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ IRInst* primalFieldExtract = nullptr;
+ primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
getPrimalInst->replaceUsesWith(primalFieldExtract);
getPrimalInst->removeAndDeallocate();
-
return primalFieldExtract;
}
}
@@ -2001,40 +2202,195 @@ struct JVPDerivativeContext
bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren)
{
bool modified = false;
+ // Hoist all pair types to global scope when possible.
+ auto moduleInst = module->getModuleInst();
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType)
+ {
+ if (originalPairType->parent != moduleInst)
+ {
+ originalPairType->removeFromParent();
+ ShortList<IRInst*> operands;
+ for (UInt i = 0; i < originalPairType->getOperandCount(); i++)
+ {
+ operands.add(originalPairType->getOperand(i));
+ }
+ auto newPairType = builder->findOrEmitHoistableInst(
+ originalPairType->getFullType(),
+ originalPairType->getOp(),
+ originalPairType->getOperandCount(),
+ operands.getArrayView().getBuffer());
+ originalPairType->replaceUsesWith(newPairType);
+ originalPairType->removeAndDeallocate();
+ }
+ });
- for (auto child = instWithChildren->getFirstChild(); child; )
- {
- // Make sure the builder is at the right level.
- builder->setInsertInto(instWithChildren);
-
- auto nextChild = child->getNextInst();
+ sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap();
- switch (child->getOp())
+ processAllInsts([&](IRInst* inst)
{
- case kIROp_DifferentialPairType:
- lowerPairType(builder, as<IRType>(child));
- break;
-
+ // Make sure the builder is at the right level.
+ builder->setInsertInto(instWithChildren);
+
+ switch (inst->getOp())
+ {
case kIROp_DifferentialPairGetDifferential:
case kIROp_DifferentialPairGetPrimal:
- lowerPairAccess(builder, child);
+ lowerPairAccess(builder, inst);
+ modified = true;
break;
-
+
case kIROp_MakeDifferentialPair:
- lowerMakePair(builder, child);
+ lowerMakePair(builder, inst);
+ modified = true;
break;
-
+
default:
- if (child->getFirstChild())
- modified = processPairTypes(builder, child) | modified;
- }
+ break;
+ }
+ });
- child = nextChild;
+ processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst)
+ {
+ if (auto loweredType = lowerPairType(builder, inst))
+ {
+ inst->replaceUsesWith(loweredType);
+ inst->removeAndDeallocate();
+ }
+ });
+ return modified;
+ }
+
+ bool simplifyDifferentialBottomType(IRBuilder* builder)
+ {
+ bool modified = false;
+ auto diffBottom = builder->getDifferentialBottom();
+
+ bool changed = true;
+ List<IRUse*> uses;
+ while (changed)
+ {
+ changed = false;
+ // Replace all insts whose type is `DifferentialBottomType` to `diffBottom`.
+ processAllInsts([&](IRInst* inst)
+ {
+ if (inst->getDataType() && inst->getDataType()->getOp() == kIROp_DifferentialBottomType)
+ {
+ if (inst != diffBottom)
+ {
+ inst->replaceUsesWith(diffBottom);
+ inst->removeAndDeallocate();
+ modified = true;
+ }
+ }
+ });
+ // Go through all uses of diffBottom and run simplification.
+ processAllInsts([&](IRInst* inst)
+ {
+ if (!inst->hasUses())
+ return;
+
+ builder->setInsertBefore(inst);
+ IRInst* valueToReplace = nullptr;
+ switch (inst->getOp())
+ {
+ case kIROp_Store:
+ if (as<IRStore>(inst)->getVal() == diffBottom)
+ {
+ inst->removeAndDeallocate();
+ changed = true;
+ }
+ return;
+ case kIROp_MakeDifferentialPair:
+ // Our simplification could lead to a situation where
+ // bottom is used to make a pair that has a non-bottom differential type,
+ // in this case we should use zero instead.
+ if (inst->getOperand(1) == diffBottom)
+ {
+ // Only apply if we are the second operand.
+ auto pairType = as<IRDifferentialPairType>(inst->getDataType());
+ if (pairBuilderStorage.getDiffTypeFromPairType(builder, pairType)->getOp() != kIROp_DifferentialBottomType)
+ {
+ auto zero = transcriberStorage.getDifferentialZeroOfType(builder, pairType->getValueType());
+ inst->setOperand(1, zero);
+ changed = true;
+ }
+ }
+ return;
+ case kIROp_DifferentialPairGetDifferential:
+ if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair)
+ {
+ valueToReplace = inst->getOperand(0)->getOperand(1);
+ }
+ break;
+ case kIROp_DifferentialPairGetPrimal:
+ if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair)
+ {
+ valueToReplace = inst->getOperand(0)->getOperand(0);
+ }
+ break;
+ case kIROp_Add:
+ if (inst->getOperand(0) == diffBottom)
+ {
+ valueToReplace = inst->getOperand(1);
+ }
+ else if (inst->getOperand(1) == diffBottom)
+ {
+ valueToReplace = inst->getOperand(0);
+ }
+ break;
+ case kIROp_Sub:
+ if (inst->getOperand(0) == diffBottom)
+ {
+ // If left is bottom, and right is not bottom, then we should return -right.
+ // However we can't possibly run into that case since both side of - operator
+ // must be at the same order of differentiation.
+ valueToReplace = diffBottom;
+ }
+ else if (inst->getOperand(1) == diffBottom)
+ {
+ valueToReplace = inst->getOperand(0);
+ }
+ break;
+ case kIROp_Mul:
+ case kIROp_Div:
+ if (inst->getOperand(0) == diffBottom)
+ {
+ valueToReplace = diffBottom;
+ }
+ else if (inst->getOperand(1) == diffBottom)
+ {
+ valueToReplace = diffBottom;
+ }
+ break;
+ default:
+ break;
+ }
+ if (valueToReplace)
+ {
+ inst->replaceUsesWith(valueToReplace);
+ changed = true;
+ }
+ });
+ modified |= changed;
}
return modified;
}
+ bool eliminateDifferentialBottomType(IRBuilder* builder)
+ {
+ simplifyDifferentialBottomType(builder);
+
+ bool modified = false;
+ auto diffBottom = builder->getDifferentialBottom();
+ auto diffBottomType = diffBottom->getDataType();
+ diffBottom->replaceUsesWith(builder->getVoidValue());
+ diffBottom->removeAndDeallocate();
+ diffBottomType->replaceUsesWith(builder->getVoidType());
+
+ return modified;
+ }
+
// Checks decorators to see if the function should
// be differentiated (kIROp_ForwardDifferentiableDecoration)
//
@@ -2074,27 +2430,18 @@ struct JVPDerivativeContext
}
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
- module(module),
+ InstPassBase(module),
sink(sink),
autoDiffSharedContextStorage(module->getModuleInst()),
- transcriberStorage(&autoDiffSharedContextStorage)
+ transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage)
{
+ pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage;
transcriberStorage.sink = sink;
transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);
transcriberStorage.pairBuilder = &(pairBuilderStorage);
}
- protected:
-
- // This type passes over the module and generates
- // forward-mode derivative versions of functions
- // that are explicitly marked for it.
- //
- IRModule* module;
-
- // Shared builder state for our derivative passes.
- SharedIRBuilder sharedBuilderStorage;
-
+protected:
// A transcriber object that handles the main job of
// processing instructions while maintaining state.
//
@@ -2104,10 +2451,6 @@ struct JVPDerivativeContext
// error messages.
DiagnosticSink* sink;
- // Work queue to hold a stream of instructions that need
- // to be checked for references to derivative functions.
- IRWorkQueue workQueueStorage;
-
// Context to find and manage the witness tables for types
// implementing `IDifferentiable`
AutoDiffSharedContext autoDiffSharedContextStorage;
diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h
index 2e251e46d..b5a1f168a 100644
--- a/source/slang/slang-ir-inst-pass-base.h
+++ b/source/slang/slang-ir-inst-pass-base.h
@@ -25,6 +25,17 @@ namespace Slang
workListSet.Add(inst);
}
+ IRInst* pop()
+ {
+ if (workList.getCount() == 0)
+ return nullptr;
+
+ IRInst* inst = workList.getLast();
+ workList.removeLast();
+ workListSet.Remove(inst);
+ return inst;
+ }
+
public:
InstPassBase(IRModule* inModule)
: module(inModule)
@@ -40,10 +51,8 @@ namespace Slang
while (workList.getCount() != 0)
{
- IRInst* inst = workList.getLast();
+ IRInst* inst = pop();
- workList.removeLast();
- workListSet.Remove(inst);
if (inst->getOp() == instOp)
{
f(as<InstType>(inst));
@@ -66,10 +75,7 @@ namespace Slang
while (workList.getCount() != 0)
{
- IRInst* inst = workList.getLast();
-
- workList.removeLast();
- workListSet.Remove(inst);
+ IRInst* inst = pop();
if (inst->getOp() == instOp)
{
f(as<InstType>(inst));
@@ -92,10 +98,8 @@ namespace Slang
while (workList.getCount() != 0)
{
- IRInst* inst = workList.getLast();
+ IRInst* inst = pop();
- workList.removeLast();
- workListSet.Remove(inst);
f(inst);
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 989777944..1d1e2ae69 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -603,7 +603,7 @@ struct IRDifferentiableTypeDictionaryItem : IRInst
IRInst* getWitness() { return getOperand(1); }
};
-struct IRDifferentiableTypeDictionaryDecoration : IRInst
+struct IRDifferentiableTypeDictionaryDecoration : IRDecoration
{
IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration)
};
@@ -2301,6 +2301,7 @@ public:
IRInst* getBoolValue(bool value);
IRInst* getIntValue(IRType* type, IRIntegerValue value);
IRInst* getFloatValue(IRType* type, IRFloatingPointValue value);
+ IRInst* getDifferentialBottom();
IRStringLit* getStringValue(const UnownedStringSlice& slice);
IRPtrLit* _getPtrValue(void* ptr);
IRPtrLit* getNullPtrValue(IRType* type);
@@ -2330,6 +2331,7 @@ public:
IRAnyValueType* getAnyValueType(IRIntegerValue size);
IRAnyValueType* getAnyValueType(IRInst* size);
IRDynamicType* getDynamicType();
+ IRDifferentialBottomType* getDifferentialBottomType();
IRTupleType* getTupleType(UInt count, IRType* const* types);
IRTupleType* getTupleType(List<IRType*> const& types)
@@ -2388,7 +2390,7 @@ public:
IRDifferentialPairType* getDifferentialPairType(
IRType* valueType,
- IRWitnessTable* witnessTable);
+ IRInst* witnessTable);
IRFuncType* getFuncType(
UInt paramCount,
@@ -2600,6 +2602,8 @@ public:
IRInst* emitGetOptionalValue(IRInst* optValue);
IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value);
IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
+ IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair);
+ IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair);
IRInst* emitMakeVector(
IRType* type,
UInt argCount,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 083ef98c5..f9686ac5b 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -1967,6 +1967,7 @@ namespace Slang
return getStringSlice() == rhs->getStringSlice();
}
case kIROp_VoidLit:
+ case kIROp_DifferentialBottomValue:
{
return true;
}
@@ -2009,6 +2010,7 @@ namespace Slang
return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength()));
}
case kIROp_VoidLit:
+ case kIROp_DifferentialBottomValue:
{
return code;
}
@@ -2074,12 +2076,20 @@ namespace Slang
}
case kIROp_VoidLit:
{
- const size_t instSize = prefixSize;
+ const size_t instSize = prefixSize + sizeof(void*);
irValue = static_cast<IRConstant*>(
_createInst(instSize, keyInst.getFullType(), keyInst.getOp()));
irValue->value.ptrVal = keyInst.value.ptrVal;
break;
}
+ case kIROp_DifferentialBottomValue:
+ {
+ const size_t instSize = prefixSize + sizeof(void*);
+ irValue = static_cast<IRConstant*>(
+ _createInst(instSize, keyInst.getFullType(), keyInst.getOp()));
+ irValue->value.ptrVal = nullptr;
+ break;
+ }
case kIROp_StringLit:
{
const UnownedStringSlice slice = keyInst.getStringSlice();
@@ -2182,6 +2192,17 @@ namespace Slang
return _findOrEmitConstant(keyInst);
}
+ IRInst* IRBuilder::getDifferentialBottom()
+ {
+ IRType* type = getDifferentialBottomType();
+ IRConstant keyInst;
+ memset(&keyInst, 0, sizeof(keyInst));
+ keyInst.m_op = kIROp_DifferentialBottomValue;
+ keyInst.typeUse.usedValue = type;
+ keyInst.value.intVal = 0;
+ return (IRInst*)_findOrEmitConstant(keyInst);
+ }
+
IRStringLit* IRBuilder::getStringValue(const UnownedStringSlice& inSlice)
{
IRConstant keyInst;
@@ -2564,6 +2585,12 @@ namespace Slang
IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); }
+ IRDifferentialBottomType* IRBuilder::getDifferentialBottomType()
+ {
+ return (IRDifferentialBottomType*)getType(kIROp_DifferentialBottomType);
+ }
+
+
IRAssociatedType* IRBuilder::getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes)
{
return (IRAssociatedType*)getType(kIROp_AssociatedType,
@@ -2760,7 +2787,7 @@ namespace Slang
IRDifferentialPairType* IRBuilder::getDifferentialPairType(
IRType* valueType,
- IRWitnessTable* witnessTable)
+ IRInst* witnessTable)
{
IRInst* operands[] = { valueType, witnessTable };
return (IRDifferentialPairType*)getType(
@@ -3389,6 +3416,25 @@ namespace Slang
return emitIntrinsicInst(type, kIROp_makeVector, argCount, args);
}
+ IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ diffType,
+ kIROp_DifferentialPairGetDifferential,
+ 1,
+ &diffPair);
+ }
+
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
+ {
+ auto valueType = as<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
+ return emitIntrinsicInst(
+ valueType,
+ kIROp_DifferentialPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitMakeMatrix(
IRType* type,
UInt argCount,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 9295ca2f5..59a61958d 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -861,6 +861,8 @@ SIMPLE_IR_TYPE(NativeStringType, StringTypeBase)
SIMPLE_IR_TYPE(DynamicType, Type)
+SIMPLE_IR_TYPE(DifferentialBottomType, Type)
+
// True if types are equal
// Note compares nominal types by name alone
bool isTypeEqual(IRType* a, IRType* b);
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 980a1d0bc..78edd4deb 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -6430,6 +6430,16 @@ namespace Slang
return modifier;
}
+ static NodeBase* parseBuiltinRequirementModifier(Parser* parser, void* /*userData*/)
+ {
+ BuiltinRequirementModifier* modifier = parser->astBuilder->create<BuiltinRequirementModifier>();
+ parser->ReadToken(TokenType::LParent);
+ modifier->kind = BuiltinRequirementKind(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent()));
+ parser->ReadToken(TokenType::RParent);
+
+ return modifier;
+ }
+
static NodeBase* parseMagicTypeModifier(Parser* parser, void* /*userData*/)
{
MagicTypeModifier* modifier = parser->astBuilder->create<MagicTypeModifier>();
@@ -6618,6 +6628,8 @@ namespace Slang
_makeParseModifier("__cuda_sm_version", parseCUDASMVersionModifier),
_makeParseModifier("__builtin_type", parseBuiltinTypeModifier),
+ _makeParseModifier("__builtin_requirement", parseBuiltinRequirementModifier),
+
_makeParseModifier("__magic_type", parseMagicTypeModifier),
_makeParseModifier("__intrinsic_type", parseIntrinsicTypeModifier),
_makeParseModifier("__implicit_conversion", parseImplicitConversionModifier),
diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp
index 8cd443438..12b9dab42 100644
--- a/source/slang/slang-syntax.cpp
+++ b/source/slang/slang-syntax.cpp
@@ -325,7 +325,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
// coerce to `DifferentialBottom`.
if (astBuilder->getDifferentialBottomType()->equals(transitiveTypeWitness->subToMid->sup))
{
- if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementAttribute>())
+ if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementModifier>())
{
if (builtinAttr->kind == BuiltinRequirementKind::DifferentialType)
{