diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-04 09:36:23 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-04 09:36:23 -0700 |
| commit | c6e6b7a9177bf4f7fc2f05da36c5952979006d78 (patch) | |
| tree | 6db694b5b4bf94ce48678c73921676f9d305614d /source/slang | |
| parent | 015bde8d5a46f32979c00dbb1feb4b3d80729c44 (diff) | |
Higher order differentiation. (#2487)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/core.meta.slang | 59 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 58 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 137 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 16 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 21 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 3 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-cleanup-void.cpp | 49 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 835 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-pass-base.h | 24 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 50 | ||||
| -rw-r--r-- | source/slang/slang-ir.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 2 |
21 files changed, 924 insertions, 394 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index ce52dbb56..05963bd11 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -105,6 +105,55 @@ interface __BuiltinSignedArithmeticType : __BuiltinArithmeticType {} interface __BuiltinIntegerType : __BuiltinArithmeticType {} + +/// Interface to denote types as differentiable. +/// Allows for user-specified differential types as +/// well as automatic generation, for when the associated type +/// hasn't been declared explicitly. +/// Note that the requirements must currently be defined in this exact order +/// since the auto-diff pass relies on the order to grab the struct keys. +/// +__magic_type(DifferentiableType) +interface IDifferentiable +{ + // Note: the compiler implementation requires the `Differential` associated type to be defined + // before anything else. + + __builtin_requirement($( (int) BuiltinRequirementKind::DifferentialType) ) + associatedtype Differential : IDifferentiable; + + __builtin_requirement($( (int)BuiltinRequirementKind::DZeroFunc) ) + static Differential dzero(); + + __builtin_requirement($( (int)BuiltinRequirementKind::DAddFunc) ) + static Differential dadd(Differential, Differential); + + __builtin_requirement($( (int)BuiltinRequirementKind::DMulFunc) ) + static Differential dmul(This, Differential); +}; + +__magic_type(DifferentialBottomType) +__intrinsic_type($(kIROp_DifferentialBottomType)) +struct __DifferentialBottom : IDifferentiable +{ + typedef __DifferentialBottom Differential; + + __intrinsic_op($(kIROp_DifferentialBottomValue)) + static __DifferentialBottom dzero(); + + [__unsafeForceInlineEarly] + static __DifferentialBottom dadd(Differential a, Differential b) + { + return dzero(); + } + + [__unsafeForceInlineEarly] + static __DifferentialBottom dmul(This a, Differential b) + { + return dzero(); + } +} + /// A type that can represent non-integers [sealed] [builtin] @@ -2739,16 +2788,6 @@ attribute_syntax [Specialize] : SpecializeAttribute; __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; -enum _BuiltinRequirementKind -{ - DifferentialType = $( (int) BuiltinRequirementKind::DifferentialType), - DZeroFunc = $( (int) BuiltinRequirementKind::DZeroFunc), - DAddFunc = $( (int) BuiltinRequirementKind::DAddFunc), - DMulFunc = $( (int) BuiltinRequirementKind::DMulFunc), -}; -__attributeTarget(DeclBase) -attribute_syntax [__BuiltinRequirement(kind: _BuiltinRequirementKind)] : BuiltinRequirementAttribute; - __attributeTarget(DeclBase) attribute_syntax [builtin] : BuiltinAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 1c3066e1d..ae4db603e 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,32 +9,6 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; -/// Interface to denote types as differentiable. -/// Allows for user-specified differential types as -/// well as automatic generation, for when the associated type -/// hasn't been declared explicitly. -/// Note that the requirements must currently be defined in this exact order -/// since the auto-diff pass relies on the order to grab the struct keys. -/// -__magic_type(DifferentiableType) -interface IDifferentiable -{ - // Note: the compiler implementation requires the `Differential` associated type to be defined - // before anything else. - - [__BuiltinRequirement(_BuiltinRequirementKind.DifferentialType)] - associatedtype Differential; - - [__BuiltinRequirement(_BuiltinRequirementKind.DZeroFunc)] - static Differential dzero(); - - [__BuiltinRequirement(_BuiltinRequirementKind.DAddFunc)] - static Differential dadd(Differential, Differential); - - [__BuiltinRequirement(_BuiltinRequirementKind.DMulFunc)] - static Differential dmul(This, Differential); -}; - // Add extensions for the standard types extension float : IDifferentiable { @@ -83,28 +57,6 @@ extension vector<float, N> : IDifferentiable } } -__magic_type(DifferentialBottomType) -__intrinsic_type($(kIROp_DifferentialBottomType)) -struct __DifferentialBottom : IDifferentiable -{ - typedef __DifferentialBottom Differential; - - __intrinsic_op($(kIROp_DifferentialBottomValue)) - static __DifferentialBottom dzero(); - - [__unsafeForceInlineEarly] - static __DifferentialBottom dadd(Differential a, Differential b) - { - return dzero(); - } - - [__unsafeForceInlineEarly] - static __DifferentialBottom dmul(This a, Differential b) - { - return dzero(); - } -} - /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. __generic<T : IDifferentiable> @@ -121,6 +73,7 @@ struct DifferentialPair : IDifferentiable __intrinsic_op($(kIROp_DifferentialPairGetDifferential)) T.Differential d(); + [__unsafeForceInlineEarly] T.Differential getDifferential() { return d(); @@ -129,6 +82,7 @@ struct DifferentialPair : IDifferentiable __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) T p(); + [__unsafeForceInlineEarly] T getPrimal() { return p(); @@ -137,7 +91,7 @@ struct DifferentialPair : IDifferentiable [__unsafeForceInlineEarly] static Differential dzero() { - return Differential(T.dzero(), Differential.DifferentialElementType.dzero()); + return Differential(T.dzero(), T.Differential.dzero()); } [__unsafeForceInlineEarly] @@ -148,15 +102,15 @@ struct DifferentialPair : IDifferentiable a.p(), b.p() ), - Differential.DifferentialElementType.dzero()); + T.Differential.dadd(a.d(), b.d())); } [__unsafeForceInlineEarly] static Differential dmul(This a, Differential b) { return Differential( - T.dmul(a.p(), b.p()), - Differential.DifferentialElementType.dzero()); + T.dmul(a.p(), b.p()), + T.Differential.dmul(a.d(), b.d())); } }; diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index beee16f9c..6249d7825 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -171,6 +171,11 @@ void SharedASTBuilder::registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modi m_builtinTypes[Index(modifier->tag)] = type; } +void SharedASTBuilder::registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier) +{ + m_builtinRequirementDecls[modifier->kind] = decl; +} + void SharedASTBuilder::registerMagicDecl(Decl* decl, MagicTypeModifier* modifier) { // In some cases the modifier will have been applied to the diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index 235bebfaa..190e3727d 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -19,6 +19,7 @@ class SharedASTBuilder : public RefObject public: void registerBuiltinDecl(Decl* decl, BuiltinTypeModifier* modifier); + void registerBuiltinRequirementDecl(Decl* decl, BuiltinRequirementModifier* modifier); void registerMagicDecl(Decl* decl, MagicTypeModifier* modifier); /// Get the string type @@ -49,6 +50,11 @@ public: Decl* tryFindMagicDecl(String const& name); + Decl* findBuiltinRequirementDecl(BuiltinRequirementKind kind) + { + return m_builtinRequirementDecls[kind].GetValue(); + } + /// A name pool that can be used for lookup for findClassInfo etc. It is the same pool as the Session. NamePool* getNamePool() { return m_namePool; } @@ -85,6 +91,7 @@ protected: Type* m_builtinTypes[Index(BaseType::CountOf)]; Dictionary<String, Decl*> m_magicDecls; + Dictionary<BuiltinRequirementKind, Decl*> m_builtinRequirementDecls; Dictionary<UnownedStringSlice, const ReflectClassInfo*> m_sliceToTypeMap; Dictionary<Name*, const ReflectClassInfo*> m_nameToTypeMap; @@ -334,6 +341,7 @@ public: Witness* primalIsDifferentialWitness); DeclRef<InterfaceDecl> getDifferentiableInterface(); + Decl* getDifferentiableAssociatedTypeRequirement(); bool isDifferentiableInterfaceAvailable(); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 67ff297dc..57dfbac9e 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -395,6 +395,15 @@ class MagicTypeModifier : public Modifier uint32_t tag = uint32_t(0); }; +// A modifier that indicates a built-in associated type requirement (e.g., `Differential`) +class BuiltinRequirementModifier : public Modifier +{ + SLANG_AST_CLASS(BuiltinRequirementModifier); + + BuiltinRequirementKind kind; +}; + + // A modifier applied to declarations of builtin types to indicate how they // should be lowered to the IR. // @@ -590,14 +599,6 @@ class Attribute : public AttributeBase AttributeArgumentValueDict intArgVals; }; -// A modifier that indicates a built-in associated type requirement (e.g., `Differential`) -class BuiltinRequirementAttribute : public Attribute -{ - SLANG_AST_CLASS(BuiltinRequirementAttribute); - - BuiltinRequirementKind kind; -}; - class UserDefinedAttribute : public Attribute { SLANG_AST_CLASS(UserDefinedAttribute) diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index badb524bb..7133f2a65 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -2,7 +2,9 @@ #include "slang-ast-base.h" #include "slang-ast-type.h" -Slang::QualType::QualType(Type* type) +namespace Slang +{ +QualType::QualType(Type* type) : type(type) , isLeftValue(false) { @@ -11,3 +13,5 @@ Slang::QualType::QualType(Type* type) isLeftValue = true; } } + +} diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 015e6969c..d4a781846 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -1510,6 +1510,7 @@ namespace Slang DAddFunc, ///< The `IDifferentiable.dadd` function requirement DMulFunc, ///< The `IDifferentiable.dmul` function requirement }; + } // namespace Slang #endif diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 7140d541a..333e9d973 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -254,6 +254,8 @@ namespace Slang void visitFunctionDeclBase(FunctionDeclBase* funcDecl); void visitParamDecl(ParamDecl* paramDecl); + + void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context); }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -1433,6 +1435,22 @@ namespace Slang synth.pushScopeForContainer(aggTypeDecl); } + // If `This` is nested inside a generic, we need to form a complete declref type to the + // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution + // from requirementDeclRef to get the generic substitution for outer generic parameters, and + // apply it to the newly synthesized decl. + SubstitutionSet substSet; + if (auto thisTypeSusbt = findThisTypeSubstitution( + requirementDeclRef.substitutions, + as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) + { + if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) + { + substSet = declRefType->declRef.substitutions; + } + } + auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + // Helper function to add a `diffType` field into the synthesized type for the original // `member`. auto differentialType = DeclRefType::create(m_astBuilder, makeDeclRef(aggTypeDecl)); @@ -1462,6 +1480,22 @@ namespace Slang addModifier(member, derivativeMemberModifier); }; + // Make the Differential type itself conform to `IDifferential` interface. + auto inheritanceIDiffernetiable = m_astBuilder->create<InheritanceDecl>(); + inheritanceIDiffernetiable->base.type = + DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface()); + inheritanceIDiffernetiable->parentDecl = aggTypeDecl; + aggTypeDecl->members.add(inheritanceIDiffernetiable); + + // The `Differential` type of a `Differential` type is always itself. + auto assocTypeDef = m_astBuilder->create<TypeDefDecl>(); + assocTypeDef->nameAndLoc.name = getName("Differential"); + assocTypeDef->type.type = satisfyingType; + assocTypeDef->parentDecl = aggTypeDecl; + assocTypeDef->setCheckState(DeclCheckState::Checked); + aggTypeDecl->members.add(assocTypeDef); + + // Go through all members and collect their differential types. // Go through super types. for (auto inheritance : context->parentDecl->getMembersOfType<InheritanceDecl>()) { @@ -1476,8 +1510,7 @@ namespace Slang } } } - - // We go through all members and generate their differential counterparts. + // Go through all var members. for (auto member : context->parentDecl->getMembersOfType<VarDeclBase>()) { auto diffType = tryGetDifferentialType(m_astBuilder, member->type.type); @@ -1488,22 +1521,9 @@ namespace Slang addModifier(aggTypeDecl, m_astBuilder->create<SynthesizedModifier>()); - // If `This` is nested inside a generic, we need to form a complete declref type to the - // newly synthesized aggTypeDecl here. This can be done by obtaining ThisTypeSubstitution - // from requirementDeclRef to get the generic substitution for outer generic parameters, and - // apply it to the newly synthesized decl. - SubstitutionSet substSet; - if (auto thisTypeSusbt = findThisTypeSubstitution( - requirementDeclRef.substitutions, - as<InterfaceDecl>(requirementDeclRef.getDecl()->parentDecl))) - { - if (auto declRefType = as<DeclRefType>(thisTypeSusbt->witness->sub)) - { - substSet = declRefType->declRef.substitutions; - } - } - - auto satisfyingType = m_astBuilder->getOrCreateDeclRefType(aggTypeDecl, substSet); + // Synthesize the rest of IDifferential method conformances by recursively checking + // conformance on the synthesized decl. + checkAggTypeConformance(aggTypeDecl); if (doesTypeSatisfyAssociatedTypeConstraintRequirement(satisfyingType, requirementDeclRef, witnessTable)) { @@ -1616,6 +1636,50 @@ namespace Slang } }; + // Check that types used as `Differential` type use themselves as their own `Differential` type. + struct SemanticsDeclDifferentialConformanceVisitor + : public SemanticsDeclVisitorBase + , public DeclVisitor<SemanticsDeclDifferentialConformanceVisitor> + { + SemanticsDeclDifferentialConformanceVisitor(SemanticsContext const& outer) + : SemanticsDeclVisitorBase(outer) + {} + void visitDecl(Decl*) {} + void visitDeclGroup(DeclGroup*) {} + + void visitInheritanceDecl(InheritanceDecl* inheritanceDecl) + { + if (as<InterfaceDecl>(inheritanceDecl->parentDecl)) + return; + + if (!inheritanceDecl->witnessTable) + return; + auto baseType = as<DeclRefType>(inheritanceDecl->witnessTable->baseType); + if (!baseType) + return; + if (baseType->declRef.getDecl() != m_astBuilder->getDifferentiableInterface().getDecl()) + return; + RequirementWitness witnessValue; + auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType); + if (!inheritanceDecl->witnessTable->requirementDictionary.TryGetValue(requirementDecl, witnessValue)) + return; + + // A type used as differential type must have itself as its own differential type. + if (witnessValue.getFlavor() != RequirementWitness::Flavor::val) + return; + auto differentialType = as<DeclRefType>(witnessValue.getVal()); + if (!differentialType) + return; + auto diffDiffType = tryGetDifferentialType(m_astBuilder, differentialType); + if (!differentialType->equals(diffDiffType)) + { + SourceLoc sourceLoc = differentialType->declRef.getDecl()->loc; + getSink()->diagnose(sourceLoc, Diagnostics::differentialTypeShouldServeAsItsOwnDifferentialType, differentialType); + getSink()->diagnose(inheritanceDecl, Diagnostics::noteSeeUseOfDifferentialType, differentialType, inheritanceDecl->getSup()); + } + } + }; + /// Recursively register any builtin declarations that need to be attached to the `session`. /// /// This function should only be needed for declarations in the standard library. @@ -1632,7 +1696,10 @@ namespace Slang { sharedASTBuilder->registerMagicDecl(decl, magicMod); } - + if (auto builtinRequirement = decl->findModifier<BuiltinRequirementModifier>()) + { + sharedASTBuilder->registerBuiltinRequirementDecl(decl, builtinRequirement); + } if(auto containerDecl = as<ContainerDecl>(decl)) { for(auto childDecl : containerDecl->members) @@ -2217,13 +2284,14 @@ namespace Slang // associated type and see if they can be satisfied. // bool conformance = true; + Val* witness = nullptr; for (auto requiredConstraintDeclRef : getMembersOfType<TypeConstraintDecl>(requiredAssociatedTypeDeclRef)) { // Grab the type we expect to conform to from the constraint. auto requiredSuperType = getSup(m_astBuilder, requiredConstraintDeclRef); // Perform a search for a witness to the subtype relationship. - auto witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); + witness = tryGetSubtypeWitness(satisfyingType, requiredSuperType); if (witness) { // If a subtype witness was found, then the conformance @@ -3040,7 +3108,7 @@ namespace Slang witnessTable)) return true; - if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) + if (auto builtinAttr = requiredFuncDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>()) { switch (builtinAttr->kind) { @@ -3067,7 +3135,7 @@ namespace Slang if (auto requiredAssocTypeDeclRef = requiredMemberDeclRef.as<AssocTypeDecl>()) { - if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementAttribute>()) + if (auto builtinAttr = requiredAssocTypeDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>()) { switch (builtinAttr->kind) { @@ -3160,7 +3228,7 @@ namespace Slang bool hasDifferentialAssocType = false; for (auto existingEntry : witnessTable->requirementList) { - if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementAttribute>()) + if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementModifier>()) { if (builtinReqAttr->kind == BuiltinRequirementKind::DifferentialType && existingEntry.Value.getFlavor() != RequirementWitness::Flavor::none) @@ -3401,7 +3469,7 @@ namespace Slang // requirement, it may be possible that we can still synthesis the // implementation if this is one of the known builtin requirements. // Otherwise, report diagnostic now. - if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementAttribute>()) + if (!requiredMemberDeclRef.getDecl()->hasModifier<BuiltinRequirementModifier>()) { getSink()->diagnose(inheritanceDecl, Diagnostics::typeDoesntImplementInterfaceRequirement, subType, requiredMemberDeclRef); getSink()->diagnose(requiredMemberDeclRef, Diagnostics::seeDeclarationOf, requiredMemberDeclRef); @@ -4499,11 +4567,29 @@ namespace Slang getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); } + void SemanticsDeclBodyVisitor::_maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context) + { + auto parentDifferentiableAttr = context.getParentDifferentiableAttribute(); + if (parentDifferentiableAttr) + { + auto diffBottomType = m_astBuilder->getDifferentialBottomType(); + auto idifferentiable = DeclRef<InterfaceDecl>(m_astBuilder->getDifferentiableInterface(), nullptr); + auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(diffBottomType, idifferentiable)); + SLANG_ASSERT(witness); + parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.Add( + as<DeclRefType>(diffBottomType)->declRef, + witness); + } + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { + auto newContext = withParentFunc(decl); + _maybeRegisterDifferentialBottomTypeConformance(newContext); + if (auto body = decl->body) { - checkBodyStmt(body, decl); + checkStmt(decl->body, newContext); } } @@ -6234,6 +6320,7 @@ namespace Slang case DeclCheckState::TypesFullyResolved: SemanticsDeclTypeResolutionVisitor(shared).dispatch(decl); + SemanticsDeclDifferentialConformanceVisitor(shared).dispatch(decl); break; case DeclCheckState::Checked: diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index ad199300a..09dd9eea1 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -428,7 +428,7 @@ namespace Slang // We will only ever need to synthesis a type to satisfy an associatedtype requirement. // In this case the lookup should have resolved to a known associatedtype decl. - auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinRequirementAttribute>(); + auto builtinAssocTypeAttr = item.declRef.getDecl()->findModifier<BuiltinRequirementModifier>(); if (!builtinAssocTypeAttr) return nullptr; diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index a0141911a..76918ebbe 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -214,6 +214,7 @@ namespace Slang Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache; Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache; }; + /// Shared state for a semantics-checking session. struct SharedSemanticsContext { @@ -274,7 +275,6 @@ namespace Slang return m_linkage->isInLanguageServer(); return false; } - /// Get the list of extension declarations that appear to apply to `decl` in this context List<ExtensionDecl*> const& getCandidateExtensionsForTypeDecl(AggTypeDecl* decl); @@ -375,6 +375,11 @@ namespace Slang return result; } + DifferentiableAttribute* getParentDifferentiableAttribute() + { + return m_parentDifferentiableAttr; + } + /// A scope that is local to a particular expression, and /// that can be used to allocate temporary bindings that /// might be needed by that expression or its sub-expressions. @@ -1041,6 +1046,15 @@ namespace Slang DeclRef<AssocTypeDecl> requirementDeclRef, RefPtr<WitnessTable> witnessTable); + struct DifferentiableMemberInfo + { + Decl* memberDecl; + Type* diffType; + }; + + /// Gather differentiable members from decl. + List<DifferentiableMemberInfo> collectDifferentiableMemberInfo(ContainerDecl* decl); + // Find the appropriate member of a declared type to // satisfy a requirement of an interface the type // claims to conform to. diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 91f655a15..d8b05198c 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -484,27 +484,6 @@ namespace Slang return false; } } - else if (auto builtinAssocTypeAttr = as<BuiltinRequirementAttribute>(attr)) - { - if (attr->args.getCount() == 1) - { - //IntVal* outIntVal; - if (auto cInt = checkConstantEnumVal(attr->args[0])) - { - builtinAssocTypeAttr->kind = (BuiltinRequirementKind)(cInt->value); - } - else - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); - return false; - } - } - else - { - getSink()->diagnose(attr, Diagnostics::expectedSingleIntArg, attr->keywordName); - return false; - } - } else if (auto unrollAttr = as<UnrollAttribute>(attr)) { // Check has an argument. We need this because default behavior is to give an error diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 9e939e476..ffee0622c 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -303,6 +303,9 @@ DIAGNOSTIC(30093, Error, uncaughtTryCallInNonThrowFunc, "the current function or DIAGNOSTIC(30094, Error, mustUseTryClauseToCallAThrowFunc, "the callee may throw an error, and therefore must be called within a 'try' clause") DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error type `$1` of callee `$0` is not compatible with the caller's error type `$2`.") +DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "type '$0' is used as a `Differential` type, therefore it must serve as its own `Differential` type.") +DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.") + // Attributes DIAGNOSTIC(31000, Error, unknownAttributeName, "unknown attribute '$0'") DIAGNOSTIC(31001, Error, attributeArgumentCountMismatch, "attribute '$0' expects $1 arguments ($2 provided)") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index fcdee78ea..9c72f1d63 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -381,9 +381,6 @@ Result linkAndOptimizeIR( // 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass) // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. - - // 3. Fill in higher-order invocations with the generated functions. - processDerivativeCalls(irModule); stripAutoDiffDecorations(irModule); diff --git a/source/slang/slang-ir-cleanup-void.cpp b/source/slang/slang-ir-cleanup-void.cpp index ac520c1d5..a72157a69 100644 --- a/source/slang/slang-ir-cleanup-void.cpp +++ b/source/slang/slang-ir-cleanup-void.cpp @@ -36,26 +36,26 @@ namespace Slang switch (inst->getOp()) { case kIROp_Call: + case kIROp_makeStruct: { // Remove void argument. - auto call = as<IRCall>(inst); List<IRInst*> newArgs; - for (UInt i = 0; i < call->getArgCount(); i++) + for (UInt i = 0; i < inst->getOperandCount(); i++) { - auto arg = call->getArg(i); + auto arg = inst->getOperand(i); if (arg->getDataType() && arg->getDataType()->getOp() == kIROp_VoidType) { continue; } newArgs.add(arg); } - if (newArgs.getCount() != (Index)call->getArgCount()) + if (newArgs.getCount() != (Index)inst->getOperandCount()) { IRBuilder builder(&sharedBuilderStorage); - builder.setInsertBefore(call); - auto newCall = builder.emitCallInst(call->getFullType(), call->getCallee(), newArgs); - call->replaceUsesWith(newCall); - call->removeAndDeallocate(); + builder.setInsertBefore(inst); + auto newCall = builder.emitIntrinsicInst(inst->getFullType(), inst->getOp(), newArgs.getCount(), newArgs.getBuffer()); + inst->replaceUsesWith(newCall); + inst->removeAndDeallocate(); inst = newCall; } } @@ -111,16 +111,43 @@ namespace Slang break; case kIROp_StructType: { - // TODO: cleanup void fields. + List<IRInst*> toRemove; + for (auto child : inst->getChildren()) + { + if (auto field = as<IRStructField>(child)) + { + if (field->getFieldType()->getOp() == kIROp_VoidType) + { + toRemove.add(field); + } + } + } + for (auto ii : toRemove) + ii->removeAndDeallocate(); } break; default: break; } - // TODO: If inst has void type, all uses of it should be replaced with void val. + // If inst has void type, all uses of it should be replaced with void val. // We should do this only for a subset of opcodes known to be safe. - + switch(inst->getOp()) + { + case kIROp_Load: + case kIROp_getElement: + case kIROp_GetOptionalValue: + case kIROp_FieldExtract: + case kIROp_GetTupleElement: + case kIROp_GetResultError: + case kIROp_GetResultValue: + if (inst->getDataType()->getOp() == kIROp_VoidType) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(inst); + inst->replaceUsesWith(builder.getVoidValue()); + } + } } void processModule() diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index d0bf8f347..8a4fe23d0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -7,6 +7,7 @@ #include "slang-ir-dce.h" #include "slang-ir-eliminate-phis.h" #include "slang-ir-util.h" +#include "slang-ir-inst-pass-base.h" // origX, primalX, diffX // origX -> primalX (cloneEnv) @@ -20,9 +21,19 @@ struct Pair { P primal; D differential; - + Pair() = default; Pair(P primal, D differential) : primal(primal), differential(differential) {} + HashCode getHashCode() const + { + Hasher hasher; + hasher << primal << differential; + return hasher.getResult(); + } + bool operator ==(const Pair& other) const + { + return primal == other.primal && differential == other.differential; + } }; typedef Pair<IRInst*, IRInst*> InstPair; @@ -43,6 +54,11 @@ struct AutoDiffSharedContext // IRStructKey* differentialAssocTypeStructKey = nullptr; + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferential`. + IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + + // The struct key for the 'zero()' associated type // defined inside IDifferential. We use this to lookup the // implementation of zero() for a given type. @@ -54,6 +70,9 @@ struct AutoDiffSharedContext // implementation of add() for a given type. // IRStructKey* addMethodStructKey = nullptr; + + IRStructKey* mulMethodStructKey = nullptr; + // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. @@ -69,8 +88,10 @@ struct AutoDiffSharedContext if (differentiableInterfaceType) { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); zeroMethodStructKey = findZeroMethodStructKey(); addMethodStructKey = findAddMethodStructKey(); + mulMethodStructKey = findMulMethodStructKey(); if (differentialAssocTypeStructKey) isInterfaceAvailable = true; @@ -103,22 +124,32 @@ struct AutoDiffSharedContext return getIDifferentiableStructKeyAtIndex(0); } - IRStructKey* findZeroMethodStructKey() + IRStructKey* findDifferentialTypeWitnessStructKey() { return getIDifferentiableStructKeyAtIndex(1); } - IRStructKey* findAddMethodStructKey() + IRStructKey* findZeroMethodStructKey() { return getIDifferentiableStructKeyAtIndex(2); } + IRStructKey* findAddMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(3); + } + + IRStructKey* findMulMethodStructKey() + { + return getIDifferentiableStructKeyAtIndex(4); + } + IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index) { if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) { - // Assume for now that IDifferentiable has exactly four fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 4); + // Assume for now that IDifferentiable has exactly five fields. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) return as<IRStructKey>(entry->getRequirementKey()); else @@ -300,7 +331,16 @@ struct DifferentialPairTypeBuilder IRInst* emitFieldAccessor(IRBuilder* builder, IRInst* baseInst, IRStructKey* key) { - if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) + auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); + if (baseTypeInfo.isTrivial) + { + if (key == globalPrimalKey) + return baseInst; + else + return builder->getDifferentialBottom(); + } + + if (auto basePairStructType = as<IRStructType>(baseTypeInfo.loweredType)) { return as<IRFieldExtract>(builder->emitFieldExtract( findField(basePairStructType, key)->getFieldType(), @@ -308,7 +348,7 @@ struct DifferentialPairTypeBuilder key )); } - else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) + else if (auto ptrType = as<IRPtrTypeBase>(baseTypeInfo.loweredType)) { if (auto ptrInnerSpecializedType = as<IRSpecialize>(ptrType->getValueType())) { @@ -334,7 +374,7 @@ struct DifferentialPairTypeBuilder key)); } } - else if (auto specializedType = as<IRSpecialize>(baseInst->getDataType())) + else if (auto specializedType = as<IRSpecialize>(baseTypeInfo.loweredType)) { // TODO: Stopped here -> The type being emitted is incorrect. don't emit the generic's // type, emit the specialization type. @@ -420,25 +460,64 @@ struct DifferentialPairTypeBuilder { SLANG_ASSERT(!as<IRParam>(origBaseType)); SLANG_ASSERT(diffType); - auto pairStructType = builder->createStructType(); - builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); - builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType); + if (diffType->getOp() != kIROp_DifferentialBottomType) + { + auto pairStructType = builder->createStructType(); + builder->createStructField(pairStructType, _getOrCreatePrimalStructKey(builder), origBaseType); + builder->createStructField(pairStructType, _getOrCreateDiffStructKey(builder), (IRType*)diffType); + return pairStructType; + } + return origBaseType; + } - return pairStructType; + struct LoweredPairTypeInfo + { + IRInst* loweredType; + bool isTrivial; + }; + + IRInst* getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type) + { + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); } - IRInst* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType, IRType* diffType) + IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type) { - if (pairTypeCache.ContainsKey(origBaseType)) - return pairTypeCache[origBaseType]; + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); + } - auto pairType = _createDiffPairType(builder, origBaseType, diffType); - pairTypeCache.Add(origBaseType, pairType); + LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType) + { + LoweredPairTypeInfo result = {}; + + if (pairTypeCache.TryGetValue(originalPairType, result)) + return result; + auto pairType = as<IRDifferentialPairType>(originalPairType); + if (!pairType) + { + result.isTrivial = true; + result.loweredType = originalPairType; + return result; + } + auto primalType = pairType->getValueType(); + if (as<IRParam>(primalType)) + { + result.isTrivial = false; + result.loweredType = nullptr; + return result; + } + + auto diffType = getDiffTypeFromPairType(builder, pairType); + result.loweredType = _createDiffPairType(builder, pairType->getValueType(), (IRType*)diffType); + result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType); + pairTypeCache.Add(originalPairType, result); - return pairType; + return result; } - Dictionary<IRInst*, IRInst*> pairTypeCache; + Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache; IRStructKey* globalPrimalKey = nullptr; @@ -447,6 +526,8 @@ struct DifferentialPairTypeBuilder IRInst* genericDiffPairType = nullptr; List<IRInst*> generatedTypeList; + + AutoDiffSharedContext* sharedContext = nullptr; }; struct JVPTranscriber @@ -474,8 +555,15 @@ struct JVPTranscriber DifferentiableTypeConformanceContext differentiableTypeConformanceContext; - JVPTranscriber(AutoDiffSharedContext* shared) - : differentiableTypeConformanceContext(shared) + List<InstPair> followUpFunctionsToTranscribe; + + SharedIRBuilder* sharedBuilder; + // Witness table that `DifferentialBottom:IDifferential`. + IRWitnessTable* differentialBottomWitness = nullptr; + Dictionary<InstPair, IRInst*> differentialPairTypes; + + JVPTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder) + : differentiableTypeConformanceContext(shared), sharedBuilder(inSharedBuilder) {} DiagnosticSink* getSink() @@ -592,8 +680,75 @@ struct JVPTranscriber return builder->getFuncType(newParameterTypes, diffReturnType); } + IRWitnessTable* getDifferentialBottomWitness() + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); + auto result = + as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + SLANG_ASSERT(result); + return result; + } + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. + IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); + SLANG_ASSERT(diffPairType); + auto result = + as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + if (result) + return result; + + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + auto diffType = differentiateType(&builder, diffPairType->getValueType()); + auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness()); + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. + + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + return table; + } + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* getOrCreateDiffPairType(IRInst* primalType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + auto witness = as<IRWitnessTable>( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + if (!witness) + witness = getDifferentialBottomWitness(); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + IRType* differentiateType(IRBuilder* builder, IRType* origType) { + IRInst* diffType = nullptr; + if (!instMapD.TryGetValue(origType, diffType)) + { + diffType = _differentiateTypeImpl(builder, origType); + instMapD[origType] = diffType; + } + return (IRType*)diffType; + } + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) + { if (auto ptrType = as<IRPtrTypeBase>(origType)) return builder->getPtrType( origType->getOp(), @@ -628,6 +783,14 @@ struct JVPTranscriber else return nullptr; } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as<IRDifferentialPairType>(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } case kIROp_FuncType: return differentiateFunctionType(builder, as<IRFuncType>(primalType)); @@ -660,7 +823,7 @@ struct JVPTranscriber return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); } } - + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) { // If this is a PtrType (out, inout, etc..), then create diff pair from @@ -675,7 +838,7 @@ struct JVPTranscriber } auto diffType = differentiateType(builder, primalType); if (diffType) - return (IRType*)pairBuilder->getOrCreateDiffPairType(builder, primalType, diffType); + return (IRType*)getOrCreateDiffPairType(primalType); return nullptr; } @@ -692,7 +855,7 @@ struct JVPTranscriber if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) { - IRParam* diffPairParam = builder->emitParam(diffPairType); + IRInst* diffPairParam = builder->emitParam(diffPairType); auto diffPairVarName = makeDiffPairName(origParam); if (diffPairVarName.getLength() > 0) @@ -700,9 +863,20 @@ struct JVPTranscriber SLANG_ASSERT(diffPairParam); - return InstPair( - pairBuilder->emitPrimalFieldAccess(builder, diffPairParam), - pairBuilder->emitDiffFieldAccess(builder, diffPairParam)); + if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); } @@ -826,30 +1000,52 @@ struct JVPTranscriber InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { auto origPtr = origLoad->getPtr(); - - auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + auto primalPtr = lookupPrimalInst(origPtr, nullptr); + auto primalPtrValueType = as<IRPtrTypeBase>(primalPtr->getFullType())->getValueType(); - IRInst* diffLoad = nullptr; + if (auto diffPairType = as<IRDifferentialPairType>(primalPtrValueType)) + { + // Special case load from an `out` param, which will not have corresponding `diff` and + // `primal` insts yet. + auto load = builder->emitLoad(primalPtr); + auto primalElement = builder->emitDifferentialPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); + return InstPair(primalElement, diffElement); + } + auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + IRInst* diffLoad = nullptr; if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) { // Default case, we're loading from a known differential inst. diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); - return InstPair(primalLoad, diffLoad); - } - return InstPair(primalLoad, nullptr); + } + return InstPair(primalLoad, diffLoad); } InstPair transcribeStore(IRBuilder* builder, IRStore* origStore) { IRInst* origStoreLocation = origStore->getPtr(); IRInst* origStoreVal = origStore->getVal(); - - auto primalStore = cloneInst(&cloneEnv, builder, origStore); - + auto primalStoreLocation = lookupPrimalInst(origStoreLocation, nullptr); auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); + auto primalStoreVal = lookupPrimalInst(origStoreVal, nullptr); auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + if (!diffStoreLocation) + { + auto primalLocationPtrType = as<IRPtrTypeBase>(primalStoreLocation->getDataType()); + if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType())) + { + auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); + auto store = builder->emitStore(primalStoreLocation, valToStore); + return InstPair(store, nullptr); + } + } + + auto primalStore = cloneInst(&cloneEnv, builder, origStore); + IRInst* diffStore = nullptr; // If the stored value has a differential version, @@ -1052,8 +1248,9 @@ struct JVPTranscriber if (diffReturnType->getOp() != kIROp_VoidType) { - IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst); - IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst); + IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst); + auto diffType = differentiateType(builder, origCall->getFullType()); + IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst); return InstPair(primalResultValue, diffResultValue); } else @@ -1174,14 +1371,16 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } - InstPair transcribeConst(IRBuilder*, IRInst* origInst) + InstPair transcribeConst(IRBuilder* builder, IRInst* origInst) { switch(origInst->getOp()) { case kIROp_FloatLit: + return InstPair(origInst, builder->getFloatValue(origInst->getDataType(), 0.0f)); case kIROp_VoidLit: + return InstPair(origInst, origInst); case kIROp_IntLit: - return InstPair(origInst, nullptr); + return InstPair(origInst, builder->getIntValue(origInst->getDataType(), 0)); } getSink()->diagnose( @@ -1245,6 +1444,14 @@ struct JVPTranscriber { if (auto diffType = differentiateType(builder, primalType)) { + switch (diffType->getOp()) + { + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + } // Since primalType has a corresponding differential type, we can lookup the // definition for zero(). auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); @@ -1458,40 +1665,63 @@ struct JVPTranscriber return InstPair(diffLoop, diffLoop); } - // Transcribe a function definition. - InstPair transcribeFunc(IRBuilder* builder, IRFunc* origFunc) + InstPair transcribeMakeDifferentialPair(IRBuilder* builder, IRMakeDifferentialPair* origInst) { - IRFunc* primalFunc = nullptr; + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalVal); + auto diffPrimalVal = findOrTranscribePrimalInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffPrimalVal); + auto primalDiffVal = findOrTranscribeDiffInst(builder, origInst->getPrimalValue()); + SLANG_ASSERT(primalDiffVal); + auto diffDiffVal = findOrTranscribeDiffInst(builder, origInst->getDifferentialValue()); + SLANG_ASSERT(diffDiffVal); - differentiableTypeConformanceContext.setFunc(origFunc); + auto primalPair = builder->emitMakeDifferentialPair(origInst->getDataType(), primalVal, diffPrimalVal); + auto diffPair = builder->emitMakeDifferentialPair( + differentiateType(builder, origInst->getDataType()), + primalDiffVal, + diffDiffVal); + return InstPair(primalPair, diffPair); + } - auto oldLoc = builder->getInsertLoc(); + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) + { + SLANG_ASSERT( + origInst->getOp() == kIROp_DifferentialPairGetDifferential || + origInst->getOp() == kIROp_DifferentialPairGetPrimal); - // If this is a top-level function, there is no need to clone it - // since it is visible in all the scopes. - // Otherwise, we need to clone it in case of generic scopes. - // - // TODO(sai): Is this the correct thing to do? Can a function cloned inside a - // generic scope but is not the return value of that generic, be used within - // that scope? Or do we have to call out to the original generic specialized with - // the current generic params? - // - bool isTopLevelFunc = (as<IRModuleInst>(origFunc->parent) != nullptr); - if (isTopLevelFunc) - { - builder->setInsertBefore(origFunc); - primalFunc = origFunc; - } + auto primalVal = findOrTranscribePrimalInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(primalVal); + + auto diffVal = findOrTranscribeDiffInst(builder, origInst->getOperand(0)); + SLANG_ASSERT(diffVal); + + auto primalResult = builder->emitIntrinsicInst(origInst->getFullType(), origInst->getOp(), 1, &primalVal); + + auto diffValPairType = as<IRDifferentialPairType>(diffVal->getDataType()); + IRInst* diffResultType = nullptr; + if (origInst->getOp() == kIROp_DifferentialPairGetDifferential) + diffResultType = pairBuilder->getDiffTypeFromPairType(builder, diffValPairType); else - { - // TODO(sai): this might never be called, and it might never make sense - // to call it either. Potentially remove this. - primalFunc = as<IRFunc>( - cloneInst(&cloneEnv, builder, origFunc)); - } + diffResultType = diffValPairType->getValueType(); + auto diffResult = builder->emitIntrinsicInst((IRType*)diffResultType, origInst->getOp(), 1, &diffVal); + return InstPair(primalResult, diffResult); + } + + // Create an empty func to represent the transcribed func of `origFunc`. + InstPair transcribeFuncHeader(IRBuilder* builder, IRFunc* origFunc) + { + auto oldLoc = builder->getInsertLoc(); + + IRFunc* primalFunc = origFunc; + + differentiableTypeConformanceContext.setFunc(origFunc); + + builder->setInsertBefore(origFunc); + primalFunc = origFunc; auto diffFunc = builder->createFunc(); - + SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); IRType* diffFuncType = this->differentiateFunctionType( builder, @@ -1505,10 +1735,33 @@ struct JVPTranscriber newNameSb << "s_jvp_" << originalName; builder->addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); } - + builder->addForwardDerivativeDecoration(origFunc, diffFunc); + + // Mark the generated derivative function itself as differentiable. + builder->addForwardDifferentiableDecoration(diffFunc); + + // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. + if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + { + cloneDecoration(dictDecor, diffFunc); + } + + // Reset builder position + builder->setInsertLoc(oldLoc); + auto result = InstPair(primalFunc, diffFunc); + followUpFunctionsToTranscribe.add(result); + return result; + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) + { + auto oldLoc = builder->getInsertLoc(); + + differentiableTypeConformanceContext.setFunc(primalFunc); // Transcribe children from origFunc into diffFunc builder->setInsertInto(diffFunc); - for (auto block = origFunc->getFirstBlock(); block; block = block->getNextBlock()) + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) this->transcribe(builder, block); // Reset builder position @@ -1685,6 +1938,11 @@ struct JVPTranscriber case kIROp_ifElse: return transcribeIfElse(builder, as<IRIfElse>(origInst)); + case kIROp_MakeDifferentialPair: + return transcribeMakeDifferentialPair(builder, as<IRMakeDifferentialPair>(origInst)); + case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferential: + return transcribeDifferentialPairGetElement(builder, origInst); } // If none of the cases have been hit, check if the instruction is a @@ -1722,7 +1980,7 @@ struct JVPTranscriber switch (origInst->getOp()) { case kIROp_Func: - return transcribeFunc(builder, as<IRFunc>(origInst)); + return transcribeFuncHeader(builder, as<IRFunc>(origInst)); case kIROp_Block: return transcribeBlock(builder, as<IRBlock>(origInst)); @@ -1741,45 +1999,7 @@ struct JVPTranscriber } }; -struct IRWorkQueue -{ - // Work list to hold the active set of insts whose children - // need to be looked at. - // - List<IRInst*> workList; - HashSet<IRInst*> workListSet; - - void push(IRInst* inst) - { - if(!inst) return; - if(workListSet.Contains(inst)) return; - - workList.add(inst); - workListSet.Add(inst); - } - - IRInst* pop() - { - if (workList.getCount() != 0) - { - IRInst* topItem = workList.getFirst(); - // TODO(Sai): Repeatedly calling removeAt() can be really slow. - // Consider a specialized data structure or using removeLast() - // - workList.removeAt(0); - workListSet.Remove(topItem); - return topItem; - } - return nullptr; - } - - IRInst* peek() - { - return workList.getFirst(); - } -}; - -struct JVPDerivativeContext +struct JVPDerivativeContext : public InstPassBase { DiagnosticSink* getSink() @@ -1795,6 +2015,7 @@ struct JVPDerivativeContext // SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); + sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; @@ -1809,8 +2030,12 @@ struct JVPDerivativeContext // IRDifferentialPairGetPrimal with 'primal' field access, and // IRMakeDifferentialPair with an IRMakeStruct. // + modified |= simplifyDifferentialBottomType(builder); + modified |= processPairTypes(builder, module->getModuleInst()); - + + modified |= eliminateDifferentialBottomType(builder); + return modified; } @@ -1826,121 +2051,92 @@ struct JVPDerivativeContext // bool processReferencedFunctions(IRBuilder* builder) { - IRWorkQueue* workQueue = &(workQueueStorage); + List<IRForwardDifferentiate*> autoDiffWorkList; - // Put the top-level inst into the queue. - workQueue->push(module->getModuleInst()); - - // Keep processing items until the queue is complete. - while (IRInst* workItem = workQueue->pop()) - { - for(auto child = workItem->getFirstChild(); child; child = child->getNextInst()) + for (;;) + { + // Collect all `ForwardDifferentiate` insts from the module. + autoDiffWorkList.clear(); + processInstsOfType<IRForwardDifferentiate>(kIROp_ForwardDifferentiate, [&](IRForwardDifferentiate* fwdDiffInst) { - // Either the child instruction has more children (func/block etc..) - // and we add it to the work list for further processing, or - // it's an ordinary inst in which case we check if it's a ForwardDifferentiate - // instruction. - // - if (child->getFirstChild() != nullptr) - workQueue->push(child); - - if (auto jvpDiffInst = as<IRForwardDifferentiate>(child)) - { - auto baseInst = jvpDiffInst->getBaseFn(); + autoDiffWorkList.add(fwdDiffInst); + }); - IRGlobalValueWithCode* baseFunction = nullptr; + if (autoDiffWorkList.getCount() == 0) + break; - if (auto specializeInst = as<IRSpecialize>(baseInst)) - { - // Certain specialize insts come with a derivative - // reference attached. Skip such instructions. - // - if (lookupJVPReference(specializeInst)) continue; - } - else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst)) + // Process collected `ForwardDifferentiate` insts and replace them with placeholders for + // differentiated functions. + transcriberStorage.followUpFunctionsToTranscribe.clear(); + + for (auto fwdDiffInst : autoDiffWorkList) + { + auto baseInst = fwdDiffInst->getBaseFn(); + if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst)) + { + if (auto existingDiffFunc = lookupJVPReference(baseFunction)) { - baseFunction = globalValWithCode; + fwdDiffInst->replaceUsesWith(existingDiffFunc); + fwdDiffInst->removeAndDeallocate(); } - - SLANG_ASSERT(baseFunction); - - // If the JVP Reference already exists, no need to - // differentiate again. - // - if (lookupJVPReference(baseFunction)) continue; - - if (isMarkedForForwardDifferentiation(baseFunction)) + else if (isMarkedForForwardDifferentiation(baseFunction)) { if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) { - IRInst* diffFunc = (&transcriberStorage)->transcribe(builder, baseFunction); + IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction); SLANG_ASSERT(diffFunc); - builder->addForwardDerivativeDecoration(baseFunction, diffFunc); - workQueue->push(diffFunc); - } + fwdDiffInst->replaceUsesWith(diffFunc); + fwdDiffInst->removeAndDeallocate(); + } else { // TODO(Sai): This would probably be better with a more specific // error code. - getSink()->diagnose(jvpDiffInst->sourceLoc, + getSink()->diagnose(fwdDiffInst->sourceLoc, Diagnostics::internalCompilerError, "Unexpected instruction. Expected func or generic"); } } - else + else { // TODO(Sai): This would probably be better with a more specific // error code. - getSink()->diagnose(jvpDiffInst->sourceLoc, + getSink()->diagnose(fwdDiffInst->sourceLoc, Diagnostics::internalCompilerError, "Cannot differentiate functions not marked for differentiation"); } } } - } - - return true; - } - - IRInst* lowerPairType(IRBuilder* builder, IRType* type) - { - - if (auto pairType = as<IRDifferentialPairType>(type)) - { - builder->setInsertBefore(pairType); - - if (!as<IRType>(pairType->getValueType())) + // Actually synthesize the derivatives. + List<InstPair> followUpWorkList = _Move(transcriberStorage.followUpFunctionsToTranscribe); + for (auto task : followUpWorkList) { - return nullptr; - } - auto witness = pairType->getWitness(); - auto diffType = _lookupWitness(builder, witness, autoDiffSharedContextStorage.differentialAssocTypeStructKey); - if (!diffType) - { - return nullptr; + auto diffFunc = as<IRFunc>(task.differential); + SLANG_ASSERT(diffFunc); + auto primalFunc = as<IRFunc>(task.primal); + SLANG_ASSERT(primalFunc); + + transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc); } - auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( - builder, - pairType->getValueType(), - (IRType*)(diffType)); - pairType->replaceUsesWith(diffPairStructType); - pairType->removeAndDeallocate(); + // Transcribing the function body really shouldn't produce more follow up function body work. + // However it may produce new `ForwardDifferentiate` instructions, which we collect and process + // in the next iteration. + SLANG_RELEASE_ASSERT(transcriberStorage.followUpFunctionsToTranscribe.getCount() == 0); - return diffPairStructType; - } - else if (auto loweredStructType = as<IRStructType>(type)) - { - // Already lowered to struct. - return loweredStructType; - } - else if (auto specializedStructType = as<IRSpecialize>(type)) - { - // Already lowered to specialized struct. - return specializedStructType; } - - return nullptr; + return true; + } + + IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr) + { + builder->setInsertBefore(pairType); + auto loweredPairTypeInfo = (&pairBuilderStorage)->lowerDiffPairType( + builder, + pairType); + if (isTrivial) + *isTrivial = loweredPairTypeInfo.isTrivial; + return loweredPairTypeInfo.loweredType; } IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) @@ -1948,19 +2144,24 @@ struct JVPDerivativeContext if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { - if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType())) + bool isTrivial = false; + auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType()); + if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial)) { builder->setInsertBefore(makePairInst); - - List<IRInst*> operands; - operands.add(makePairInst->getPrimalValue()); - operands.add(makePairInst->getDifferentialValue()); - - auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); - makePairInst->replaceUsesWith(makeStructInst); + IRInst* result = nullptr; + if (isTrivial) + { + result = makePairInst->getPrimalValue(); + } + else + { + IRInst* operands[2] = { makePairInst->getPrimalValue(), makePairInst->getDifferentialValue() }; + result = builder->emitMakeStruct((IRType*)(loweredPairType), 2, operands); + } + makePairInst->replaceUsesWith(result); makePairInst->removeAndDeallocate(); - - return makeStructInst; + return result; } } @@ -1971,11 +2172,11 @@ struct JVPDerivativeContext { if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) { - if (lowerPairType(builder, getDiffInst->getBase()->getDataType())) + if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), nullptr)) { builder->setInsertBefore(getDiffInst); - - auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); + IRInst* diffFieldExtract = nullptr; + diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); getDiffInst->replaceUsesWith(diffFieldExtract); getDiffInst->removeAndDeallocate(); return diffFieldExtract; @@ -1983,14 +2184,14 @@ struct JVPDerivativeContext } else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) { - if (lowerPairType(builder, getPrimalInst->getBase()->getDataType())) + if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), nullptr)) { builder->setInsertBefore(getPrimalInst); - auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + IRInst* primalFieldExtract = nullptr; + primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); getPrimalInst->replaceUsesWith(primalFieldExtract); getPrimalInst->removeAndDeallocate(); - return primalFieldExtract; } } @@ -2001,40 +2202,195 @@ struct JVPDerivativeContext bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; + // Hoist all pair types to global scope when possible. + auto moduleInst = module->getModuleInst(); + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) + { + if (originalPairType->parent != moduleInst) + { + originalPairType->removeFromParent(); + ShortList<IRInst*> operands; + for (UInt i = 0; i < originalPairType->getOperandCount(); i++) + { + operands.add(originalPairType->getOperand(i)); + } + auto newPairType = builder->findOrEmitHoistableInst( + originalPairType->getFullType(), + originalPairType->getOp(), + originalPairType->getOperandCount(), + operands.getArrayView().getBuffer()); + originalPairType->replaceUsesWith(newPairType); + originalPairType->removeAndDeallocate(); + } + }); - for (auto child = instWithChildren->getFirstChild(); child; ) - { - // Make sure the builder is at the right level. - builder->setInsertInto(instWithChildren); - - auto nextChild = child->getNextInst(); + sharedBuilderStorage.deduplicateAndRebuildGlobalNumberingMap(); - switch (child->getOp()) + processAllInsts([&](IRInst* inst) { - case kIROp_DifferentialPairType: - lowerPairType(builder, as<IRType>(child)); - break; - + // Make sure the builder is at the right level. + builder->setInsertInto(instWithChildren); + + switch (inst->getOp()) + { case kIROp_DifferentialPairGetDifferential: case kIROp_DifferentialPairGetPrimal: - lowerPairAccess(builder, child); + lowerPairAccess(builder, inst); + modified = true; break; - + case kIROp_MakeDifferentialPair: - lowerMakePair(builder, child); + lowerMakePair(builder, inst); + modified = true; break; - + default: - if (child->getFirstChild()) - modified = processPairTypes(builder, child) | modified; - } + break; + } + }); - child = nextChild; + processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRDifferentialPairType* inst) + { + if (auto loweredType = lowerPairType(builder, inst)) + { + inst->replaceUsesWith(loweredType); + inst->removeAndDeallocate(); + } + }); + return modified; + } + + bool simplifyDifferentialBottomType(IRBuilder* builder) + { + bool modified = false; + auto diffBottom = builder->getDifferentialBottom(); + + bool changed = true; + List<IRUse*> uses; + while (changed) + { + changed = false; + // Replace all insts whose type is `DifferentialBottomType` to `diffBottom`. + processAllInsts([&](IRInst* inst) + { + if (inst->getDataType() && inst->getDataType()->getOp() == kIROp_DifferentialBottomType) + { + if (inst != diffBottom) + { + inst->replaceUsesWith(diffBottom); + inst->removeAndDeallocate(); + modified = true; + } + } + }); + // Go through all uses of diffBottom and run simplification. + processAllInsts([&](IRInst* inst) + { + if (!inst->hasUses()) + return; + + builder->setInsertBefore(inst); + IRInst* valueToReplace = nullptr; + switch (inst->getOp()) + { + case kIROp_Store: + if (as<IRStore>(inst)->getVal() == diffBottom) + { + inst->removeAndDeallocate(); + changed = true; + } + return; + case kIROp_MakeDifferentialPair: + // Our simplification could lead to a situation where + // bottom is used to make a pair that has a non-bottom differential type, + // in this case we should use zero instead. + if (inst->getOperand(1) == diffBottom) + { + // Only apply if we are the second operand. + auto pairType = as<IRDifferentialPairType>(inst->getDataType()); + if (pairBuilderStorage.getDiffTypeFromPairType(builder, pairType)->getOp() != kIROp_DifferentialBottomType) + { + auto zero = transcriberStorage.getDifferentialZeroOfType(builder, pairType->getValueType()); + inst->setOperand(1, zero); + changed = true; + } + } + return; + case kIROp_DifferentialPairGetDifferential: + if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) + { + valueToReplace = inst->getOperand(0)->getOperand(1); + } + break; + case kIROp_DifferentialPairGetPrimal: + if (inst->getOperand(0)->getOp() == kIROp_MakeDifferentialPair) + { + valueToReplace = inst->getOperand(0)->getOperand(0); + } + break; + case kIROp_Add: + if (inst->getOperand(0) == diffBottom) + { + valueToReplace = inst->getOperand(1); + } + else if (inst->getOperand(1) == diffBottom) + { + valueToReplace = inst->getOperand(0); + } + break; + case kIROp_Sub: + if (inst->getOperand(0) == diffBottom) + { + // If left is bottom, and right is not bottom, then we should return -right. + // However we can't possibly run into that case since both side of - operator + // must be at the same order of differentiation. + valueToReplace = diffBottom; + } + else if (inst->getOperand(1) == diffBottom) + { + valueToReplace = inst->getOperand(0); + } + break; + case kIROp_Mul: + case kIROp_Div: + if (inst->getOperand(0) == diffBottom) + { + valueToReplace = diffBottom; + } + else if (inst->getOperand(1) == diffBottom) + { + valueToReplace = diffBottom; + } + break; + default: + break; + } + if (valueToReplace) + { + inst->replaceUsesWith(valueToReplace); + changed = true; + } + }); + modified |= changed; } return modified; } + bool eliminateDifferentialBottomType(IRBuilder* builder) + { + simplifyDifferentialBottomType(builder); + + bool modified = false; + auto diffBottom = builder->getDifferentialBottom(); + auto diffBottomType = diffBottom->getDataType(); + diffBottom->replaceUsesWith(builder->getVoidValue()); + diffBottom->removeAndDeallocate(); + diffBottomType->replaceUsesWith(builder->getVoidType()); + + return modified; + } + // Checks decorators to see if the function should // be differentiated (kIROp_ForwardDifferentiableDecoration) // @@ -2074,27 +2430,18 @@ struct JVPDerivativeContext } JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : - module(module), + InstPassBase(module), sink(sink), autoDiffSharedContextStorage(module->getModuleInst()), - transcriberStorage(&autoDiffSharedContextStorage) + transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage) { + pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage; transcriberStorage.sink = sink; transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); transcriberStorage.pairBuilder = &(pairBuilderStorage); } - protected: - - // This type passes over the module and generates - // forward-mode derivative versions of functions - // that are explicitly marked for it. - // - IRModule* module; - - // Shared builder state for our derivative passes. - SharedIRBuilder sharedBuilderStorage; - +protected: // A transcriber object that handles the main job of // processing instructions while maintaining state. // @@ -2104,10 +2451,6 @@ struct JVPDerivativeContext // error messages. DiagnosticSink* sink; - // Work queue to hold a stream of instructions that need - // to be checked for references to derivative functions. - IRWorkQueue workQueueStorage; - // Context to find and manage the witness tables for types // implementing `IDifferentiable` AutoDiffSharedContext autoDiffSharedContextStorage; diff --git a/source/slang/slang-ir-inst-pass-base.h b/source/slang/slang-ir-inst-pass-base.h index 2e251e46d..b5a1f168a 100644 --- a/source/slang/slang-ir-inst-pass-base.h +++ b/source/slang/slang-ir-inst-pass-base.h @@ -25,6 +25,17 @@ namespace Slang workListSet.Add(inst); } + IRInst* pop() + { + if (workList.getCount() == 0) + return nullptr; + + IRInst* inst = workList.getLast(); + workList.removeLast(); + workListSet.Remove(inst); + return inst; + } + public: InstPassBase(IRModule* inModule) : module(inModule) @@ -40,10 +51,8 @@ namespace Slang while (workList.getCount() != 0) { - IRInst* inst = workList.getLast(); + IRInst* inst = pop(); - workList.removeLast(); - workListSet.Remove(inst); if (inst->getOp() == instOp) { f(as<InstType>(inst)); @@ -66,10 +75,7 @@ namespace Slang while (workList.getCount() != 0) { - IRInst* inst = workList.getLast(); - - workList.removeLast(); - workListSet.Remove(inst); + IRInst* inst = pop(); if (inst->getOp() == instOp) { f(as<InstType>(inst)); @@ -92,10 +98,8 @@ namespace Slang while (workList.getCount() != 0) { - IRInst* inst = workList.getLast(); + IRInst* inst = pop(); - workList.removeLast(); - workListSet.Remove(inst); f(inst); for (auto child = inst->getLastChild(); child; child = child->getPrevInst()) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 989777944..1d1e2ae69 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -603,7 +603,7 @@ struct IRDifferentiableTypeDictionaryItem : IRInst IRInst* getWitness() { return getOperand(1); } }; -struct IRDifferentiableTypeDictionaryDecoration : IRInst +struct IRDifferentiableTypeDictionaryDecoration : IRDecoration { IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration) }; @@ -2301,6 +2301,7 @@ public: IRInst* getBoolValue(bool value); IRInst* getIntValue(IRType* type, IRIntegerValue value); IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); + IRInst* getDifferentialBottom(); IRStringLit* getStringValue(const UnownedStringSlice& slice); IRPtrLit* _getPtrValue(void* ptr); IRPtrLit* getNullPtrValue(IRType* type); @@ -2330,6 +2331,7 @@ public: IRAnyValueType* getAnyValueType(IRIntegerValue size); IRAnyValueType* getAnyValueType(IRInst* size); IRDynamicType* getDynamicType(); + IRDifferentialBottomType* getDifferentialBottomType(); IRTupleType* getTupleType(UInt count, IRType* const* types); IRTupleType* getTupleType(List<IRType*> const& types) @@ -2388,7 +2390,7 @@ public: IRDifferentialPairType* getDifferentialPairType( IRType* valueType, - IRWitnessTable* witnessTable); + IRInst* witnessTable); IRFuncType* getFuncType( UInt paramCount, @@ -2600,6 +2602,8 @@ public: IRInst* emitGetOptionalValue(IRInst* optValue); IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value); IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); + IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); IRInst* emitMakeVector( IRType* type, UInt argCount, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 083ef98c5..f9686ac5b 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -1967,6 +1967,7 @@ namespace Slang return getStringSlice() == rhs->getStringSlice(); } case kIROp_VoidLit: + case kIROp_DifferentialBottomValue: { return true; } @@ -2009,6 +2010,7 @@ namespace Slang return combineHash(code, Slang::getHashCode(slice.begin(), slice.getLength())); } case kIROp_VoidLit: + case kIROp_DifferentialBottomValue: { return code; } @@ -2074,12 +2076,20 @@ namespace Slang } case kIROp_VoidLit: { - const size_t instSize = prefixSize; + const size_t instSize = prefixSize + sizeof(void*); irValue = static_cast<IRConstant*>( _createInst(instSize, keyInst.getFullType(), keyInst.getOp())); irValue->value.ptrVal = keyInst.value.ptrVal; break; } + case kIROp_DifferentialBottomValue: + { + const size_t instSize = prefixSize + sizeof(void*); + irValue = static_cast<IRConstant*>( + _createInst(instSize, keyInst.getFullType(), keyInst.getOp())); + irValue->value.ptrVal = nullptr; + break; + } case kIROp_StringLit: { const UnownedStringSlice slice = keyInst.getStringSlice(); @@ -2182,6 +2192,17 @@ namespace Slang return _findOrEmitConstant(keyInst); } + IRInst* IRBuilder::getDifferentialBottom() + { + IRType* type = getDifferentialBottomType(); + IRConstant keyInst; + memset(&keyInst, 0, sizeof(keyInst)); + keyInst.m_op = kIROp_DifferentialBottomValue; + keyInst.typeUse.usedValue = type; + keyInst.value.intVal = 0; + return (IRInst*)_findOrEmitConstant(keyInst); + } + IRStringLit* IRBuilder::getStringValue(const UnownedStringSlice& inSlice) { IRConstant keyInst; @@ -2564,6 +2585,12 @@ namespace Slang IRDynamicType* IRBuilder::getDynamicType() { return (IRDynamicType*)getType(kIROp_DynamicType); } + IRDifferentialBottomType* IRBuilder::getDifferentialBottomType() + { + return (IRDifferentialBottomType*)getType(kIROp_DifferentialBottomType); + } + + IRAssociatedType* IRBuilder::getAssociatedType(ArrayView<IRInterfaceType*> constraintTypes) { return (IRAssociatedType*)getType(kIROp_AssociatedType, @@ -2760,7 +2787,7 @@ namespace Slang IRDifferentialPairType* IRBuilder::getDifferentialPairType( IRType* valueType, - IRWitnessTable* witnessTable) + IRInst* witnessTable) { IRInst* operands[] = { valueType, witnessTable }; return (IRDifferentialPairType*)getType( @@ -3389,6 +3416,25 @@ namespace Slang return emitIntrinsicInst(type, kIROp_makeVector, argCount, args); } + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPairGetDifferential, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + { + auto valueType = as<IRDifferentialPairType>(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitMakeMatrix( IRType* type, UInt argCount, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 9295ca2f5..59a61958d 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -861,6 +861,8 @@ SIMPLE_IR_TYPE(NativeStringType, StringTypeBase) SIMPLE_IR_TYPE(DynamicType, Type) +SIMPLE_IR_TYPE(DifferentialBottomType, Type) + // True if types are equal // Note compares nominal types by name alone bool isTypeEqual(IRType* a, IRType* b); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 980a1d0bc..78edd4deb 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -6430,6 +6430,16 @@ namespace Slang return modifier; } + static NodeBase* parseBuiltinRequirementModifier(Parser* parser, void* /*userData*/) + { + BuiltinRequirementModifier* modifier = parser->astBuilder->create<BuiltinRequirementModifier>(); + parser->ReadToken(TokenType::LParent); + modifier->kind = BuiltinRequirementKind(StringToInt(parser->ReadToken(TokenType::IntegerLiteral).getContent())); + parser->ReadToken(TokenType::RParent); + + return modifier; + } + static NodeBase* parseMagicTypeModifier(Parser* parser, void* /*userData*/) { MagicTypeModifier* modifier = parser->astBuilder->create<MagicTypeModifier>(); @@ -6618,6 +6628,8 @@ namespace Slang _makeParseModifier("__cuda_sm_version", parseCUDASMVersionModifier), _makeParseModifier("__builtin_type", parseBuiltinTypeModifier), + _makeParseModifier("__builtin_requirement", parseBuiltinRequirementModifier), + _makeParseModifier("__magic_type", parseMagicTypeModifier), _makeParseModifier("__intrinsic_type", parseIntrinsicTypeModifier), _makeParseModifier("__implicit_conversion", parseImplicitConversionModifier), diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 8cd443438..12b9dab42 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -325,7 +325,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt // coerce to `DifferentialBottom`. if (astBuilder->getDifferentialBottomType()->equals(transitiveTypeWitness->subToMid->sup)) { - if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementAttribute>()) + if (auto builtinAttr = requirementKey->findModifier<BuiltinRequirementModifier>()) { if (builtinAttr->kind == BuiltinRequirementKind::DifferentialType) { |
