diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-08-05 13:19:20 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-08-05 13:19:20 -0400 |
| commit | 2db8c15c04f2aade49636e42f0adee636afb3b73 (patch) | |
| tree | 774758a9f854ddf655f6c46765a3ef8ca1950857 | |
| parent | 12a846e8facf090aaeb68fcabf55867f5eaed747 (diff) | |
Added a new differential type system and various improvements (#2343)
* Merge slang-ir-diff-jvp.cpp
* Added support and tests for other float vector types
* Added swizzle test and code to handle it (tests failing currently)
* Fixed one test, the other is still pending
* Fixed instruction cloning logic to avoid modifying original function
* Fixed an issue with custom 'pow_jvp' and added support for vector contructor
* Minor update to comments
* Fixed support for division
* Fixed an issue with uninitialized diagnostic sink
* Moved derivative processing to after mandatory inlining.
Skip instructions that don't have side-effects and aren't used by anything.
* WIP: Handling unconditional control flow and multi-block functions
* Support for unconditional multi-block functions
* Added a dead code elimination step to the derivative pass
* Changed name of 'hasNoSideEffects()'
* Refactored variable names
* Added initial IR defs for new type system
* Added necessary logic for semantic checking
* Overhauled type system to use builtin pair types and conform to the IDifferentiable interface
* Automatically replace IRDifferentiablePairType to a custom IRStructType
* Added generics handling by expanding the conformance context functionality and allowing for type parameters
* Minor fix: early return in processPairTypes()
* Minor fixes to differentiable resolution on generic types
* Added new instructions for differential pairs. Basic tests work now.
Looking into generic types.
* Adjusted most tests to the new type system. OutType and InOutType are still not properly working.
* Updated __jvp to produce both primal and differential output
* Moved autodiff related declarations to diff.meta.slang
* Refactored variable names
* Added initial IR defs for new type system
* Added necessary logic for semantic checking
* Overhauled type system to use builtin pair types and conform to the IDifferentiable interface
* Automatically replace IRDifferentiablePairType to a custom IRStructType
* Added generics handling by expanding the conformance context functionality and allowing for type parameters
* Minor fix: early return in processPairTypes()
* Minor fixes to differentiable resolution on generic types
* Added new instructions for differential pairs. Basic tests work now.
Looking into generic types.
* Adjusted most tests to the new type system. OutType and InOutType are still not properly working.
* Updated __jvp to produce both primal and differential output
* Moved autodiff related declarations to diff.meta.slang
* Removed external changes
* Cleanup the transcription logic: each case returns a pair of insts for the primal and differential computation.
30 files changed, 1256 insertions, 431 deletions
diff --git a/.gitignore b/.gitignore index 409435a79..574e7d3a1 100644 --- a/.gitignore +++ b/.gitignore @@ -46,6 +46,7 @@ tests/**/*.slang-module /source/slang/slang-generated-*.h /source/slang/hlsl.meta.slang.h /source/slang/core.meta.slang.h +/source/slang/diff.meta.slang.h prelude/*.h.cpp /source/slang/cpp.hint /source/slang/slang-value-generated.h diff --git a/build/visual-studio/run-generators/run-generators.vcxproj b/build/visual-studio/run-generators/run-generators.vcxproj index c8a99e931..e95f1ce02 100644 --- a/build/visual-studio/run-generators/run-generators.vcxproj +++ b/build/visual-studio/run-generators/run-generators.vcxproj @@ -343,6 +343,23 @@ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release|x64'">../../../bin/windows-x64/release/slang-generate.exe</AdditionalInputs>
<AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release aarch64|ARM'">../../../bin/windows-x64/release/slang-generate.exe</AdditionalInputs>
</CustomBuild>
+ <CustomBuild Include="..\..\..\source\slang\diff.meta.slang">
+ <FileType>Document</FileType>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">"../../../bin/windows-x86/debug/slang-generate" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">"../../../bin/windows-x64/debug/slang-generate" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Debug aarch64|ARM'">"$(SolutionDir)/bin/windows-x64/debug/slang-generate" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">"../../../bin/windows-x86/release/slang-generate" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Release|x64'">"../../../bin/windows-x64/release/slang-generate" %(Identity)</Command>
+ <Command Condition="'$(Configuration)|$(Platform)'=='Release aarch64|ARM'">"$(SolutionDir)/bin/windows-x64/release/slang-generate" %(Identity)</Command>
+ <Outputs>../../../source/slang/diff.meta.slang.h</Outputs>
+ <Message>slang-generate %(Identity)</Message>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">../../../bin/windows-x86/debug/slang-generate.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Debug|x64'">../../../bin/windows-x64/debug/slang-generate.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Debug aarch64|ARM'">../../../bin/windows-x64/debug/slang-generate.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release|Win32'">../../../bin/windows-x86/release/slang-generate.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release|x64'">../../../bin/windows-x64/release/slang-generate.exe</AdditionalInputs>
+ <AdditionalInputs Condition="'$(Configuration)|$(Platform)'=='Release aarch64|ARM'">../../../bin/windows-x64/release/slang-generate.exe</AdditionalInputs>
+ </CustomBuild>
<CustomBuild Include="..\..\..\source\slang\hlsl.meta.slang">
<FileType>Document</FileType>
<Command Condition="'$(Configuration)|$(Platform)'=='Debug|Win32'">"../../../bin/windows-x86/debug/slang-generate" %(Identity)</Command>
diff --git a/build/visual-studio/run-generators/run-generators.vcxproj.filters b/build/visual-studio/run-generators/run-generators.vcxproj.filters index 609af9779..d91507ae6 100644 --- a/build/visual-studio/run-generators/run-generators.vcxproj.filters +++ b/build/visual-studio/run-generators/run-generators.vcxproj.filters @@ -40,6 +40,9 @@ <CustomBuild Include="..\..\..\source\slang\core.meta.slang">
<Filter>Source Files</Filter>
</CustomBuild>
+ <CustomBuild Include="..\..\..\source\slang\diff.meta.slang">
+ <Filter>Source Files</Filter>
+ </CustomBuild>
<CustomBuild Include="..\..\..\source\slang\hlsl.meta.slang">
<Filter>Source Files</Filter>
</CustomBuild>
diff --git a/build/visual-studio/slang/slang.vcxproj b/build/visual-studio/slang/slang.vcxproj index f0cccb341..5f8c0bd2d 100644 --- a/build/visual-studio/slang/slang.vcxproj +++ b/build/visual-studio/slang/slang.vcxproj @@ -619,6 +619,7 @@ IF EXIST ..\..\..\external\slang-glslang\bin\windows-aarch64\release\slang-glsla </ItemGroup>
<ItemGroup>
<None Include="..\..\..\source\slang\core.meta.slang" />
+ <None Include="..\..\..\source\slang\diff.meta.slang" />
<None Include="..\..\..\source\slang\hlsl.meta.slang" />
</ItemGroup>
<ItemGroup>
diff --git a/build/visual-studio/slang/slang.vcxproj.filters b/build/visual-studio/slang/slang.vcxproj.filters index ea1fd80ea..18bebd332 100644 --- a/build/visual-studio/slang/slang.vcxproj.filters +++ b/build/visual-studio/slang/slang.vcxproj.filters @@ -946,6 +946,9 @@ <None Include="..\..\..\source\slang\core.meta.slang">
<Filter>Source Files</Filter>
</None>
+ <None Include="..\..\..\source\slang\diff.meta.slang">
+ <Filter>Source Files</Filter>
+ </None>
<None Include="..\..\..\source\slang\hlsl.meta.slang">
<Filter>Source Files</Filter>
</None>
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 862595b90..41f066486 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -71,10 +71,6 @@ syntax snorm : SNormModifier; /// syntax __extern_cpp : ExternCppModifier; -/// Modifer to mark a function for forward-mode differentiation. -/// i.e. the compiler will automatically generate a new function -/// that computes the jacobian-vector product of the original. -syntax __differentiate_jvp : JVPDerivativeModifier; /// A type that can be used as an operand for builtins [sealed] @@ -697,7 +693,6 @@ ${{{{ } }}}} - //@ public: /// Sampling state for filtered texture fetches. @@ -2321,7 +2316,3 @@ attribute_syntax [noinline] : NoInlineAttribute; __attributeTarget(StructDecl) attribute_syntax [payload] : PayloadAttribute; - -// Custom JVP Function reference -__attributeTarget(FuncDecl) -attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang new file mode 100644 index 000000000..e604140ae --- /dev/null +++ b/source/slang/diff.meta.slang @@ -0,0 +1,70 @@ + +/// Modifer to mark a function for forward-mode differentiation. +/// i.e. the compiler will automatically generate a new function +/// that computes the jacobian-vector product of the original. +syntax __differentiate_jvp : JVPDerivativeModifier; + +// Custom JVP Function reference +__attributeTarget(FuncDecl) +attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; + +//@ public: + + /// Interface to denote types as differentiable. + /// Allows for user-specified differential types as + /// well as automatic generation, for when the associated type + /// hasn't been declared explicitly. +__magic_type(DifferentiableType) +interface IDifferentiable +{ + associatedtype Differential; +}; + + /// Pair type that serves to wrap the primal and + /// differential types of an arbitrary type T. +__generic<T : IDifferentiable> +__magic_type(DifferentialPairType) +__intrinsic_type($(kIROp_DifferentialPairType)) +struct __DifferentialPair +{ + + __intrinsic_op($(kIROp_MakeDifferentialPair)) + __init(T _primal, T.Differential _differential); + + __intrinsic_op($(kIROp_DifferentialPairGetDifferential)) + T.Differential d(); + + T.Differential getDifferential() + { + return d(); + } + + __intrinsic_op($(kIROp_DifferentialPairGetPrimal)) + T p(); + + T getPrimal() + { + return p(); + } +}; + +// Add extensions for the standard types +extension float : IDifferentiable +{ + typedef float Differential; +} + +extension vector<float, 3> : IDifferentiable +{ + typedef vector<float, 3> Differential; +} + +extension vector<float, 2> : IDifferentiable +{ + typedef vector<float, 2> Differential; +} + +extension vector<float, 4> : IDifferentiable +{ + typedef vector<float, 4> Differential; +} diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 7f23f5492..868763f76 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -265,7 +265,7 @@ VectorExpressionType* ASTBuilder::getVectorType( IntVal* elementCount) { auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector")); - + auto vectorTypeDecl = vectorGenericDecl->inner; auto substitutions = create<GenericSubstitution>(); @@ -278,6 +278,30 @@ VectorExpressionType* ASTBuilder::getVectorType( return as<VectorExpressionType>(DeclRefType::create(this, declRef)); } +DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness) +{ + auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType")); + + auto typeDecl = genericDecl->inner; + + auto substitutions = create<GenericSubstitution>(); + substitutions->genericDecl = genericDecl; + substitutions->args.add(valueType); + substitutions->args.add(conformanceWitness); + + auto declRef = DeclRef<Decl>(typeDecl, substitutions); + auto rsType = DeclRefType::create(this, declRef); + + return as<DifferentialPairType>(rsType); +} + +DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface() +{ + DeclRef<InterfaceDecl> declRef; + declRef.decl = dynamicCast<InterfaceDecl>(m_sharedASTBuilder->findMagicDecl("DifferentiableType")); + return declRef; +} + DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs) { DeclRef<Decl> declRef; diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index d8a8679ee..3c0303e70 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -164,6 +164,10 @@ public: VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount); + DifferentialPairType* getDifferentialPairType(Type* valueType, Witness* conformanceWitness); + + DeclRef<InterfaceDecl> getDifferentiableInterface(); + DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs); Type* getAndType(Type* left, Type* right); diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index b82c6b182..5d4e42bfb 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -437,6 +437,21 @@ protected: }; +// A differential pair type, e.g., `__DifferentialPair<T>` +class DifferentialPairType : public ArithmeticExpressionType +{ + SLANG_AST_CLASS(DifferentialPairType) + + // The type of vector elements. + // As an invariant, this should be a basic type or an alias. + Type* baseType = nullptr; +}; + +class DifferentiableType : public BuiltinType +{ + SLANG_AST_CLASS(DifferentiableType) +}; + // A vector type, e.g., `vector<T,N>` class VectorExpressionType : public ArithmeticExpressionType { diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b7f99c4e7..a787af211 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1526,48 +1526,42 @@ namespace Slang return expr; } - // This function proceses primal params (i.e params of the inner function that is being - // differentiated) that need to be carried over to the function signature for the JVP - // function. (eg. out types can be discarded) - // - Type* primalToInputType(ASTBuilder*, Type* primalType) - { - if (auto primalOutType = as<OutType>(primalType)) - return nullptr; - else if (auto primalInOutType = as<InOutType>(primalType)) - return primalInOutType->getValueType(); - - return primalType; - } - Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType) + Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) { - // Only float and vector<float> types can be differentiated for now. - - if (primalType->equals(builder->getFloatType())) - return primalType; - else if (auto primalVectorType = as<VectorExpressionType>(primalType)) - { - if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType)) - return builder->getVectorType(jvpElementType, primalVectorType->elementCount); - } - else if (auto primalOutType = as<OutType>(primalType)) + // Check for type modifiers like 'out' and 'inout'. We need to differentiate the + // nested type. + // + if (auto primalOutType = as<OutType>(primalType)) { - return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType())); + return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType())); } else if (auto primalInOutType = as<InOutType>(primalType)) { - return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType())); + return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType())); } - return nullptr; + + // Get a reference to the builtin 'IDifferentiable' interface + auto differentiableInterface = builder->getDifferentiableInterface(); + + // Check if the provided type inherits from IDifferentiable. + // If not, return the original type. + if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface))) + return builder->getDifferentialPairType(primalType, conformanceWitness); + else + return primalType; + } - Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType) + Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType) { - if(auto jvpType = primalToJVPParamType(builder, primalType)) - return jvpType; + if (auto conformanceWitness = + as<Witness>(tryGetInterfaceConformanceWitness( + primalType, + builder->getDifferentiableInterface()))) + return builder->getDifferentialPairType(primalType, conformanceWitness); else - return builder->getVoidType(); + return primalType; } Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) @@ -1588,7 +1582,7 @@ namespace Slang // The JVP return type is float if primal return type is float // void otherwise. // - jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType()); + jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType()); // No support for differentiating function that throw errors, for now. SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); @@ -1596,13 +1590,7 @@ namespace Slang for (UInt i = 0; i < primalType->getParamCount(); i++) { - if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i))) - jvpType->paramTypes.add(primalInputType); - } - - for (UInt i = 0; i < primalType->getParamCount(); i++) - { - if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i))) + if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i))) jvpType->paramTypes.add(jvpParamType); } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ebc6e05d5..f2f7a6bd1 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -673,6 +673,15 @@ namespace Slang void _validateCircularVarDefinition(VarDeclBase* varDecl); bool shouldSkipChecking(Decl* decl, DeclCheckState state); + + // Auto-diff convenience functions for translating primal types to differential types. + Type* _toDifferentialParamType(ASTBuilder* builder, Type* primalType); + + // Translate a return type to the return type of a forward-mode differentiated + // function. + // + Type* _toJVPReturnType(ASTBuilder* builder, Type* primalType); + public: bool ValuesAreEqual( diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index 5ba121742..fd338361e 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2950,6 +2950,7 @@ namespace Slang Scope* coreLanguageScope = nullptr; Scope* hlslLanguageScope = nullptr; Scope* slangLanguageScope = nullptr; + Scope* autodiffLanguageScope = nullptr; ModuleDecl* baseModuleDecl = nullptr; List<RefPtr<Module>> stdlibModules; @@ -2981,10 +2982,12 @@ namespace Slang String slangLibraryCode; String hlslLibraryCode; String glslLibraryCode; + String autodiffLibraryCode; String getStdlibPath(); String getCoreLibraryCode(); String getHLSLLibraryCode(); + String getAutodiffLibraryCode(); RefPtr<SharedASTBuilder> m_sharedASTBuilder; diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index fd1d0086d..554a407ee 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -9,19 +9,349 @@ namespace Slang { +template<typename P, typename D> +struct Pair +{ + P primal; + D differential; + + Pair(P primal, D differential) : primal(primal), differential(differential) + {} +}; + +typedef Pair<IRInst*, IRInst*> InstPair; + +struct DifferentiableTypeConformanceContext +{ + Dictionary<IRInst*, IRInst*> witnessTableMap; + + IRInst* inst = nullptr; + + // A reference to the builtin IDifferentiable interface type. + // We use this to look up all the other types (and type exprs) + // that conform to a base type. + // + IRInterfaceType* differentiableInterfaceType = nullptr; + + // The struct key for the 'Differential' associated type + // defined inside IDifferential. We use this to lookup the differential + // type in the conformance table associated with the concrete type. + // + IRStructKey* differentialAssocTypeStructKey = nullptr; + + // Modules that don't use differentiable types + // won't have the IDifferentiable interface type available. + // Set to false to indicate that we are uninitialized. + // + bool isInterfaceAvailable = false; + + // For handling generic blocks, we use a parent pointer to allow + // looking up types in all relevant scopes. + DifferentiableTypeConformanceContext* parent = nullptr; + + DifferentiableTypeConformanceContext(DifferentiableTypeConformanceContext* parent, IRInst* inst) : parent(parent), inst(inst) + { + if (parent) + { + differentiableInterfaceType = parent->differentiableInterfaceType; + differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey; + isInterfaceAvailable = parent->isInterfaceAvailable; + } + else + { + differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface()); + if (differentiableInterfaceType) + { + differentialAssocTypeStructKey = findDifferentialTypeStructKey(); + + if (differentialAssocTypeStructKey) + isInterfaceAvailable = true; + } + } + + if (isInterfaceAvailable) + { + // Load all witness tables corresponding to the IDifferentiable interface. + loadWitnessTablesForInterface(differentiableInterfaceType); + } + } + + DifferentiableTypeConformanceContext(IRInst* inst) : + DifferentiableTypeConformanceContext(nullptr, inst) + {} + + // Lookup a witness table for the concreteType. One should exist if concreteType + // inherits (successfully) from IDifferentiable. + // + IRInst* lookUpConformanceForType(IRInst* type) + { + SLANG_ASSERT(isInterfaceAvailable); + + if (witnessTableMap.ContainsKey(type)) + return witnessTableMap[type]; + else if (parent) + return parent->lookUpConformanceForType(type); + else + return nullptr; + } + + // Lookup and return the 'Differential' type declared in the concrete type + // in order to conform to the IDifferentiable interface. + // Note that inside a generic block, this will be a witness table lookup instruction + // that gets resolved during the specialization pass. + // + IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType) + { + SLANG_ASSERT(isInterfaceAvailable); + + if (auto conformance = lookUpConformanceForType(origType)) + { + if (auto witnessTable = as<IRWitnessTable>(conformance)) + { + for (auto entry : witnessTable->getEntries()) + { + if (entry->getRequirementKey() == differentialAssocTypeStructKey) + return as<IRType>(entry->getSatisfyingVal()); + } + } + else if (auto witnessTableParam = as<IRParam>(conformance)) + { + return builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + witnessTableParam, + differentialAssocTypeStructKey); + } + } + + return nullptr; + } + + private: + + IRInst* findDifferentiableInterface() + { + if (auto module = as<IRModuleInst>(inst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + // TODO: This seems like a particularly dangerous way to look for an interface. + // See if we can lower IDifferentiable to a separate IR inst. + // + if (globalInst->getOp() == kIROp_InterfaceType && + as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable") + { + return globalInst; + } + } + } + return nullptr; + } + + IRStructKey* findDifferentialTypeStructKey() + { + if (as<IRModuleInst>(inst) && differentiableInterfaceType) + { + // Assume for now that IDifferentiable has exactly one field: the 'Differential' associated type. + SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 1); + if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(0))) + return as<IRStructKey>(entry->getRequirementKey()); + else + { + SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); + } + } + + return nullptr; + } + + void loadWitnessTablesForInterface(IRInst* interfaceType) + { + + if (auto module = as<IRModuleInst>(inst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + if (globalInst->getOp() == kIROp_WitnessTable && + cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() == + interfaceType) + { + // TODO: Can we have multiple conformances for the same pair of types? + // TODO: Can type instrs be duplicated (i.e. two different float types)? And if they are duplicated, can + // we supply the dictionary with a custom equality rule that uses 'type1->equals(type2)' + witnessTableMap.Add(as<IRWitnessTable>(globalInst)->getConcreteType(), globalInst); + } + } + } + else if (auto generic = as<IRGeneric>(inst)) + { + List<IRParam*> typeParams; + + auto genericParam = generic->getFirstParam(); + while (genericParam) + { + if (as<IRTypeType>(genericParam->getDataType())) + { + typeParams.add(genericParam); + } + else + break; + + genericParam = genericParam->getNextParam(); + } + + UCount tableIndex = 0; + while (genericParam) + { + SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType())); + if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType())) + { + if (witnessTableType->getConformanceType() == differentiableInterfaceType) + witnessTableMap.Add(typeParams[tableIndex], genericParam); + } + else + break; + + tableIndex += 1; + genericParam = genericParam->getNextParam(); + } + + } + + } + +}; + +struct DifferentialPairTypeBuilder +{ + + DifferentialPairTypeBuilder(DifferentiableTypeConformanceContext* diffConformanceContext) : + diffConformanceContext(diffConformanceContext) + {} + + IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst) + { + if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) + { + auto primalField = as<IRStructField>(basePairStructType->getFirstChild()); + SLANG_ASSERT(primalField); + + return as<IRFieldExtract>(builder->emitFieldExtract( + primalField->getFieldType(), + baseInst, + primalField->getKey() + )); + } + else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) + { + if (auto pairStructType = as<IRStructType>(ptrType->getValueType())) + { + auto primalField = as<IRStructField>(pairStructType->getFirstChild()); + SLANG_ASSERT(primalField); + + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType(primalField->getFieldType()), + baseInst, + primalField->getKey() + )); + } + } + else + { + SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>"); + } + return nullptr; + } + + IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst) + { + if (auto basePairStructType = as<IRStructType>(baseInst->getDataType())) + { + auto diffField = as<IRStructField>(basePairStructType->getFirstChild()->getNextInst()); + SLANG_ASSERT(diffField); + + return as<IRFieldExtract>(builder->emitFieldExtract( + diffField->getFieldType(), + baseInst, + diffField->getKey() + )); + } + else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType())) + { + if (auto pairStructType = as<IRStructType>(ptrType->getValueType())) + { + auto diffField = as<IRStructField>(pairStructType->getFirstChild()->getNextInst()); + SLANG_ASSERT(diffField); + + return as<IRFieldAddress>(builder->emitFieldAddress( + builder->getPtrType(diffField->getFieldType()), + baseInst, + diffField->getKey() + )); + } + } + else + { + SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>"); + } + return nullptr; + } + + IRStructType* _createDiffPairType(IRBuilder* builder, IRType* origBaseType) + { + if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType)) + { + auto diffPairType = builder->createStructType(); + + // Create a keys for the primal and differential fields. + IRStructKey* origKey = builder->createStructKey(); + builder->addNameHintDecoration(origKey, UnownedTerminatedStringSlice("primal")); + builder->createStructField(diffPairType, origKey, origBaseType); + + IRStructKey* diffKey = builder->createStructKey(); + builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential")); + builder->createStructField(diffPairType, diffKey, (IRType*)(diffBaseType)); + + return diffPairType; + } + return nullptr; + } + + IRStructType* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType) + { + if (pairTypeCache.ContainsKey(origBaseType)) + return pairTypeCache[origBaseType]; + + auto pairType = _createDiffPairType(builder, origBaseType); + pairTypeCache.Add(origBaseType, pairType); + + return pairType; + } + + Dictionary<IRType*, IRStructType*> pairTypeCache; + + DifferentiableTypeConformanceContext* diffConformanceContext; + +}; + struct JVPTranscriber { // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent // their differential values. - Dictionary<IRInst*, IRInst*> instMapD; + Dictionary<IRInst*, IRInst*> instMapD; // Cloning environment to hold mapping from old to new copies for the primal // instructions. - IRCloneEnv cloneEnv; + IRCloneEnv cloneEnv; // Diagnostic sink for error messages. - DiagnosticSink* sink; + DiagnosticSink* sink; + + // Type conformance information. + DifferentiableTypeConformanceContext* diffConformanceContext; + + // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct + DifferentialPairTypeBuilder* pairBuilder; DiagnosticSink* getSink() { @@ -29,284 +359,318 @@ struct JVPTranscriber return sink; } - void mapDifferentialInst(IRInst* instP, IRInst* instD) + void mapDifferentialInst(IRInst* origInst, IRInst* diffInst) { - instMapD.Add(instP, instD); + instMapD.Add(origInst, diffInst); } - IRInst* getDifferentialInst(IRInst* instP) + void mapPrimalInst(IRInst* origInst, IRInst* primalInst) { - return instMapD[instP]; + if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst) + { + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "inconsistent primal instruction for original"); + } + else + { + cloneEnv.mapOldValToNew[origInst] = primalInst; + } } - IRInst* getDifferentialInst(IRInst* instP, IRInst* defaultInst) + IRInst* lookupDiffInst(IRInst* origInst) { - return (hasDifferentialInst(instP)) ? instMapD[instP] : defaultInst; + return instMapD[origInst]; } - bool hasDifferentialInst(IRInst* instP) + IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst) { - return instMapD.ContainsKey(instP); + return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst; } - IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + bool hasDifferentialInst(IRInst* origInst) { - List<IRType*> parameterTypesD; - IRType* returnTypeD; + return instMapD.ContainsKey(origInst); + } - // Add all primal parameters to the list. - for (UIndex i = 0; i < funcType->getParamCount(); i++) - { - // TODO(sai): Move this check to a separate function. - if (!as<IROutType>(funcType->getParamType(i))) - parameterTypesD.add(funcType->getParamType(i)); + IRInst* lookupPrimalInst(IRInst* origInst) + { + return cloneEnv.mapOldValToNew[origInst]; + } + + IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst) + { + return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst; + } + + bool hasPrimalInst(IRInst* origInst) + { + return cloneEnv.mapOldValToNew.ContainsKey(origInst); + } + + IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) + { + if (!hasDifferentialInst(origInst)) + { + transcribe(builder, origInst); + SLANG_ASSERT(hasDifferentialInst(origInst)); + } + + return lookupDiffInst(origInst); + } + + IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) + { + if (!hasPrimalInst(origInst)) + { + transcribe(builder, origInst); + SLANG_ASSERT(hasPrimalInst(origInst)); } - // Add differential versions for the types we support. + return lookupPrimalInst(origInst); + } + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + { + List<IRType*> newParameterTypes; + IRType* diffReturnType; + for (UIndex i = 0; i < funcType->getParamCount(); i++) - { - if (auto typeD = differentiateType(builder, funcType->getParamType(i))) - parameterTypesD.add(typeD); + { + auto origType = funcType->getParamType(i); + if (auto diffPairType = tryGetDiffPairType(builder, origType)) + newParameterTypes.add(diffPairType); + else + newParameterTypes.add(origType); } - // Transcribe return type. + // Transcribe return type to a pair. // This will be void if the primal return type is non-differentiable. // - returnTypeD = differentiateType(builder, funcType->getResultType()); - if (!returnTypeD) - returnTypeD = builder->getVoidType(); + if (auto returnPairType = tryGetDiffPairType(builder, funcType->getResultType())) + diffReturnType = returnPairType; + else + diffReturnType = builder->getVoidType(); - return builder->getFuncType(parameterTypesD, returnTypeD); + return builder->getFuncType(newParameterTypes, diffReturnType); } - IRType* differentiateType(IRBuilder* builder, IRType* typeP) + IRType* differentiateType(IRBuilder* builder, IRType* origType) { - switch (typeP->getOp()) + switch (origType->getOp()) { case kIROp_HalfType: case kIROp_FloatType: case kIROp_DoubleType: - return builder->getType(typeP->getOp()); case kIROp_VectorType: - // TODO(sai): Call differentiateType() on typeP. - return as<IRVectorType>(typeP); + return (IRType*)(diffConformanceContext->getDifferentialForType(builder, origType)); case kIROp_OutType: - return builder->getOutType(differentiateType(builder, as<IROutType>(typeP)->getValueType())); + return builder->getOutType(differentiateType(builder, as<IROutType>(origType)->getValueType())); case kIROp_InOutType: - return builder->getInOutType(differentiateType(builder, as<IRInOutType>(typeP)->getValueType())); + return builder->getInOutType(differentiateType(builder, as<IRInOutType>(origType)->getValueType())); default: return nullptr; } } - IRInst* differentiateParam(IRBuilder* builder, IRParam* paramP) - { - if (IRType* typeD = differentiateType(builder, paramP->getFullType())) - { - IRParam* paramD = builder->emitParam(typeD); - - auto nameHintD = getJVPVarName(paramP); - if (nameHintD.getLength() > 0) - builder->addNameHintDecoration(paramD, nameHintD.getUnownedSlice()); - - SLANG_ASSERT(paramD); - return paramD; - } - return nullptr; - } - - IRInst* emitInputParam(IRBuilder* builder, IRParam* paramP) + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* origType) { - // Convert primal 'inout' types into pure input types, because a - // JVP transformed function must never have primal side-effects. + // If this is a PtrType (out, inout, etc..), then create diff pair from + // value type and re-apply the appropropriate PtrType wrapper. // - if (auto inoutTypeP = as<IRInOutType>(paramP->getDataType())) + if (auto origPtrType = as<IRPtrTypeBase>(origType)) { - auto newParamP = builder->emitParam(inoutTypeP->getValueType()); - cloneEnv.mapOldValToNew.Add(paramP, newParamP); - cloneInstDecorationsAndChildren(&cloneEnv, builder->getSharedBuilder(), paramP, newParamP); - - return newParamP; - } - else if (as<IROutType>(paramP->getDataType())) - { - getSink()->diagnose(paramP->sourceLoc, - Diagnostics::unexpected, - "encountered unexpected output parameter"); - return nullptr; + if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) + return builder->getPtrType(origType->getOp(), diffPairValueType); + else + return nullptr; } - else - return as<IRParam>(cloneInst(&cloneEnv, builder, paramP)); + + return pairBuilder->getOrCreateDiffPairType(builder, origType); } - List<IRParam*> transcribeParams(IRBuilder* builder, IRInstList<IRParam> paramListP) + InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) { - // Clone (and emit) all the primal parameters. - List<IRParam*> newParamListP; - for (auto paramP : paramListP) + if (auto diffPairType = tryGetDiffPairType(builder, origParam->getFullType())) { - if(isPurelyFunctional(builder, paramP)) - newParamListP.add(as<IRParam>(emitInputParam(builder, paramP))); + IRParam* diffPairParam = builder->emitParam(diffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffPairParam); + + return InstPair( + pairBuilder->emitPrimalFieldAccess(builder, diffPairParam), + pairBuilder->emitDiffFieldAccess(builder, diffPairParam)); } + + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } - // Now emit differentials. - List<IRParam*> newParamListD; - for (auto paramP : paramListP) + // Returns "d<var-name>" to use as a name hint for variables and parameters. + // If no primal name is available, returns a blank string. + // + String getJVPVarName(IRInst* origVar) + { + if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) { - IRParam* paramD = as<IRParam>(differentiateParam(builder, paramP)); - mapDifferentialInst(findCloneForOperand(&cloneEnv, paramP), paramD); - newParamListD.add(paramD); + return ("d" + String(namehintDecoration->getName())); } - return newParamListD; + return String(""); } - // Returns "d<var-name>" to use as a name hint for variables and parameters. + // Returns "dp<var-name>" to use as a name hint for parameters. // If no primal name is available, returns a blank string. // - String getJVPVarName(IRInst* varP) + String makeDiffPairName(IRInst* origVar) { - if (auto namehintDecoration = varP->findDecoration<IRNameHintDecoration>()) + if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) { - return ("d" + String(namehintDecoration->getName())); + return ("dp" + String(namehintDecoration->getName())); } return String(""); } - IRInst* differentiateVar(IRBuilder* builder, IRVar* varP) + InstPair transcribeVar(IRBuilder* builder, IRVar* origVar) { - if (IRType* typeD = differentiateType(builder, varP->getDataType()->getValueType())) + if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) { - IRVar* varD = builder->emitVar(typeD); - SLANG_ASSERT(varD); + IRVar* diffVar = builder->emitVar(diffType); + SLANG_ASSERT(diffVar); - auto nameHintD = getJVPVarName(varP); - if (nameHintD.getLength() > 0) - builder->addNameHintDecoration(varD, nameHintD.getUnownedSlice()); + auto diffNameHint = getJVPVarName(origVar); + if (diffNameHint.getLength() > 0) + builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice()); - return varD; + return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar); } - return nullptr; + + return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr); } - IRInst* differentiateBinaryArith(IRBuilder* builder, IRInst* arith) + InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith) { - SLANG_ASSERT(arith->getOperandCount() == 2); + SLANG_ASSERT(origArith->getOperandCount() == 2); + + IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith); - auto leftP = arith->getOperand(0); - auto rightP = arith->getOperand(1); + auto origLeft = origArith->getOperand(0); + auto origRight = origArith->getOperand(1); - auto leftD = getDifferentialInst(leftP); - auto rightD = getDifferentialInst(rightP); + auto primalLeft = findOrTranscribePrimalInst(builder, origLeft); + auto primalRight = findOrTranscribePrimalInst(builder, origRight); + + auto diffLeft = findOrTranscribeDiffInst(builder, origLeft); + auto diffRight = findOrTranscribeDiffInst(builder, origRight); - auto leftZero = builder->getFloatValue(leftP->getDataType(), 0.0); - auto rightZero = builder->getFloatValue(rightP->getDataType(), 0.0); + auto leftZero = builder->getFloatValue(origLeft->getDataType(), 0.0); + auto rightZero = builder->getFloatValue(origRight->getDataType(), 0.0); - if (leftD || rightD) + if (diffLeft || diffRight) { - leftD = leftD ? leftD : leftZero; - rightD = rightD ? rightD : rightZero; + diffLeft = diffLeft ? diffLeft : leftZero; + diffRight = diffRight ? diffRight : rightZero; - // Might have to do special-case handling for non-scalar types, - // like float3 or float3x3 - // - auto resultType = arith->getDataType(); - switch(arith->getOp()) + auto resultType = origArith->getDataType(); + switch(origArith->getOp()) { case kIROp_Add: - return builder->emitAdd(resultType, leftD, rightD); + return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight)); case kIROp_Mul: - return builder->emitAdd(resultType, - builder->emitMul(resultType, leftD, rightP), - builder->emitMul(resultType, leftP, rightD)); + return InstPair(primalArith, builder->emitAdd(resultType, + builder->emitMul(resultType, diffLeft, primalRight), + builder->emitMul(resultType, primalLeft, diffRight))); case kIROp_Sub: - return builder->emitSub(resultType, leftD, rightD); + return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight)); case kIROp_Div: - return builder->emitDiv(resultType, + return InstPair(primalArith, builder->emitDiv(resultType, builder->emitSub( resultType, - builder->emitMul(resultType, leftD, rightP), - builder->emitMul(resultType, leftP, rightD)), + builder->emitMul(resultType, diffLeft, primalRight), + builder->emitMul(resultType, primalLeft, diffRight)), builder->emitMul( - rightP->getDataType(), rightP, rightP - )); + primalRight->getDataType(), primalRight, primalRight + ))); default: - getSink()->diagnose(arith->sourceLoc, + getSink()->diagnose(origArith->sourceLoc, Diagnostics::unimplemented, "this arithmetic instruction cannot be differentiated"); } } - return nullptr; + return InstPair(primalArith, nullptr); } - IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP) + InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad) { - auto ptrP = loadP->getPtr(); - if (as<IRVar>(ptrP) || as<IRParam>(ptrP)) - { - // If the loaded parameter has a differential version, - // emit a load instruction for the differential parameter. - // Otherwise, emit nothing since there's nothing to load. - // - if (auto ptrD = getDifferentialInst(ptrP, nullptr)) - { - IRLoad* loadD = as<IRLoad>(builder->emitLoad(ptrD)); - SLANG_ASSERT(loadD); - return loadD; - } - return nullptr; + auto origPtr = origLoad->getPtr(); + + auto primalLoad = cloneInst(&cloneEnv, builder, origLoad); + + if (auto diffPtr = lookupDiffInst(origPtr, nullptr)) + { + IRLoad* diffLoad = as<IRLoad>(builder->emitLoad(diffPtr)); + SLANG_ASSERT(diffLoad); + + return InstPair(primalLoad, diffLoad); } - else - getSink()->diagnose(loadP->sourceLoc, - Diagnostics::unimplemented, - "this load instruction cannot be differentiated"); - return nullptr; + return InstPair(primalLoad, nullptr); } - IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP) + InstPair transcribeStore(IRBuilder* builder, IRStore* origStore) { - IRInst* storeLocation = storeP->getPtr(); - IRInst* storeVal = storeP->getVal(); - if (as<IRVar>(storeLocation) || as<IRParam>(storeLocation)) - { - // If the stored value has a differential version, - // emit a store instruction for the differential parameter. - // Otherwise, emit nothing since there's nothing to load. - // - IRInst* storeValD = getDifferentialInst(storeVal); - IRInst* storeLocationD = getDifferentialInst(storeLocation); - if (storeValD && storeLocationD) - { - IRStore* storeD = as<IRStore>( - builder->emitStore(storeLocationD, storeValD)); - SLANG_ASSERT(storeD); - return storeD; - } - return nullptr; + IRInst* origStoreLocation = origStore->getPtr(); + IRInst* origStoreVal = origStore->getVal(); + + auto primalStore = cloneInst(&cloneEnv, builder, origStore); + + auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr); + auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr); + + // If the stored value has a differential version, + // emit a store instruction for the differential parameter. + // Otherwise, emit nothing since there's nothing to load. + // + if (diffStoreLocation && diffStoreVal) + { + IRStore* diffStore = as<IRStore>( + builder->emitStore(diffStoreLocation, diffStoreVal)); + SLANG_ASSERT(diffStore); + + return InstPair(primalStore, diffStore); } - else - getSink()->diagnose(storeP->sourceLoc, - Diagnostics::unimplemented, - "this store instruction cannot be differentiated"); - return nullptr; + + return InstPair(primalStore, nullptr); } - IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP) + InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn) { - IRInst* returnVal = returnP->getVal(); - if (auto returnValD = getDifferentialInst(returnVal, nullptr)) + IRInst* origReturnVal = origReturn->getVal(); + + if (auto pairType = tryGetDiffPairType(builder, origReturnVal->getDataType())) { - IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD)); - SLANG_ASSERT(returnD); - return returnD; + IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); + + IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal); + if(!diffReturnVal) + diffReturnVal = getZeroOfType(builder, origReturnVal->getDataType()); + + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); + IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); + return InstPair(pairReturn, pairReturn); } else { // If the differential return value is not available, emit a // void return. - return builder->emitReturn(); + IRInst* voidReturn = builder->emitReturn(); + return InstPair(voidReturn, voidReturn); } } @@ -314,68 +678,72 @@ struct JVPTranscriber // instruction, we check to make sure that the nested instr is a constant // and then return nullptr. Literals do not need to be differentiated. // - IRInst* differentiateConstruct(IRBuilder*, IRInst* consP) + InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct) { - if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1) - return nullptr; + IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct); + + if (as<IRConstant>(origConstruct->getOperand(0)) && origConstruct->getOperandCount() == 1) + return InstPair(primalConstruct, nullptr); else - getSink()->diagnose(consP->sourceLoc, + getSink()->diagnose(origConstruct->sourceLoc, Diagnostics::unimplemented, "this construct instruction cannot be differentiated"); - return nullptr; + + return InstPair(primalConstruct, nullptr); } // Differentiating a call instruction here is primarily about generating // an appropriate call list based on whichever parameters have differentials // in the current transcription context. - // Note(sai): Currently we don't look at modifiers (in, out, const etc..) in the function - // type, and so only support 'plain' parameters. We need to validte this somewhere to - // avoid weird behaviour // - IRInst* differentiateCall(IRBuilder* builder, IRCall* callP) + InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) { - if (auto calleeP = as<IRFunc>(callP->getCallee())) + if (auto origCallee = as<IRFunc>(origCall->getCallee())) { // Build the differential callee - IRInst* calleeD = builder->emitJVPDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(calleeP->getFullType())), - calleeP); + IRInst* diffCall = builder->emitJVPDifferentiateInst( + differentiateFunctionType(builder, as<IRFuncType>(origCallee->getFullType())), + origCallee); List<IRInst*> args; - // Go over the parameter list and all primal arguments. - for (UIndex ii = 0; ii < callP->getArgCount(); ii++) + // Go over the parameter list and create pairs for each input (if required) + for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) { - args.add(callP->getArg(ii)); - } + auto origArg = origCall->getArg(ii); + auto primalArg = findOrTranscribePrimalInst(builder, origArg); + SLANG_ASSERT(primalArg); - { - IRParam* param = calleeP->getFirstParam(); - // Go over the parameter list again and arguments for types that need differentials. - for (UIndex ii = 0; ii < callP->getArgCount(); ii++) + auto origType = origArg->getDataType(); + if (auto pairType = tryGetDiffPairType(builder, origType)) { - // Look the parameter up in the callee's signature. If it requires a derivative, proceed. - // Otherwise, continue. - // - if (differentiateType(builder, param->getDataType())) - { - // If the corresponding argument does not have a differential, create and place a - // 0 argument. - // - auto argP = callP->getArg(ii); - if (auto argD = getDifferentialInst(argP, nullptr)) - args.add(argD); - else - args.add(getZeroOfType(builder, argP->getDataType())); - } + + auto diffArg = findOrTranscribeDiffInst(builder, origArg); - param = param->getNextParam(); + // TODO(sai): This part is flawed. Replace with a call to the + // 'zero()' interface method. + if (!diffArg) + diffArg = getZeroOfType(builder, origType); + + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); + + args.add(diffPair); + } + else + { + // Add original/primal argument. + args.add(primalArg); } } - return builder->emitCallInst(differentiateType(builder, callP->getFullType()), - calleeD, - args); + auto callInst = builder->emitCallInst( + tryGetDiffPairType(builder, origCall->getFullType()), + diffCall, + args); + + return InstPair( + pairBuilder->emitPrimalFieldAccess(builder, callInst), + pairBuilder->emitDiffFieldAccess(builder, callInst)); } else { @@ -384,31 +752,40 @@ struct JVPTranscriber // differentiate such calls safely. // TODO(sai): Should probably get checked in the front-end. // - getSink()->diagnose(callP->sourceLoc, + getSink()->diagnose(origCall->sourceLoc, Diagnostics::internalCompilerError, "attempting to differentiate unresolved callee"); } - return nullptr; + + return InstPair(nullptr, nullptr); } - IRInst* differentiateSwizzle(IRBuilder* builder, IRSwizzle* swizzleP) + InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) { - if (auto baseD = getDifferentialInst(swizzleP->getBase(), nullptr)) + IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle); + + if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr)) { List<IRInst*> swizzleIndices; - for (UIndex ii = 0; ii < swizzleP->getElementCount(); ii++) - swizzleIndices.add(swizzleP->getElementIndex(ii)); + for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++) + swizzleIndices.add(origSwizzle->getElementIndex(ii)); - return builder->emitSwizzle(differentiateType(builder, swizzleP->getDataType()), - baseD, - swizzleP->getElementCount(), - swizzleIndices.getBuffer()); + return InstPair( + primalSwizzle, + builder->emitSwizzle( + differentiateType(builder, origSwizzle->getDataType()), + diffBase, + origSwizzle->getElementCount(), + swizzleIndices.getBuffer())); } - return nullptr; + + return InstPair(primalSwizzle, nullptr); } - IRInst* differentiateByPassthrough(IRBuilder* builder, IRInst* origInst) + InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst) { + IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst); + UCount operandCount = origInst->getOperandCount(); List<IRInst*> diffOperands; @@ -419,20 +796,22 @@ struct JVPTranscriber // Otherwise, abandon the differentiation attempt and assume that origInst // cannot (or does not need to) be differentiated. // - if (auto diffInst = getDifferentialInst(origInst->getOperand(ii), nullptr)) + if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr)) diffOperands.add(diffInst); else - return nullptr; + return InstPair(primalInst, nullptr); } - return builder->emitIntrinsicInst( - differentiateType(builder, origInst->getDataType()), - origInst->getOp(), - operandCount, - diffOperands.getBuffer()); + return InstPair( + primalInst, + builder->emitIntrinsicInst( + differentiateType(builder, origInst->getDataType()), + origInst->getOp(), + operandCount, + diffOperands.getBuffer())); } - IRInst* handleControlFlow(IRBuilder* builder, IRInst* origInst) + InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst) { switch(origInst->getOp()) { @@ -443,17 +822,41 @@ struct JVPTranscriber if (origBranch->getOperandCount() > 1) break; - if (auto diffBlock = getDifferentialInst(origBranch->getTargetBlock(), nullptr)) - return builder->emitBranch(as<IRBlock>(diffBlock)); - else - return nullptr; + IRInst* diffBranch = nullptr; + + if (auto diffBlock = lookupDiffInst(origBranch->getTargetBlock(), nullptr)) + diffBranch = builder->emitBranch(as<IRBlock>(diffBlock)); + + // For now, every block in the original fn must have a corresponding + // block to compute both primals and derivatives. + SLANG_ASSERT(diffBranch); + + return InstPair(diffBranch, diffBranch); } getSink()->diagnose( origInst->sourceLoc, Diagnostics::unimplemented, "attempting to differentiate unhandled control flow"); - return nullptr; + + return InstPair(nullptr, nullptr); + } + + + InstPair transcribeConst(IRBuilder*, IRInst* origInst) + { + switch(origInst->getOp()) + { + case kIROp_FloatLit: + return InstPair(origInst, nullptr); + } + + getSink()->diagnose( + origInst->sourceLoc, + Diagnostics::unimplemented, + "attempting to differentiate unhandled const type"); + + return InstPair(nullptr, nullptr); } // In differential computation, the 'default' differential value is always zero. @@ -470,6 +873,15 @@ struct JVPTranscriber return builder->getFloatValue(type, 0.0); case kIROp_IntType: return builder->getIntValue(type, 0); + case kIROp_VectorType: + { + IRInst* args[] = {getZeroOfType(builder, as<IRVectorType>(type)->getElementType())}; + return builder->emitIntrinsicInst( + type, + kIROp_constructVectorFromScalar, + 1, + args); + } default: getSink()->diagnose(type->sourceLoc, Diagnostics::internalCompilerError, @@ -478,126 +890,85 @@ struct JVPTranscriber } } - // Logic for whether a primal instruction needs to be replicated - // in the differential function. We detect and avoid replicating - // 'side-effect' instructions. - // - bool isPurelyFunctional(IRBuilder*, IRInst* instP) + IRInst* transcribe(IRBuilder* builder, IRInst* origInst) { - if (as<IRTerminatorInst>(instP)) - return false; - else if (auto paramP = as<IRParam>(instP)) - { - // Out-type parameters are discarded from the parameter list, - // since pure JVP functions to not write to primal outputs. - // - if (as<IROutType>(paramP->getDataType())) - return false; - } - else if (auto storeP = as<IRStore>(instP)) - { - IRInst* storeLocation = storeP->getPtr(); + InstPair pair = transcribeInst(builder, origInst); - // Writing to a parameter is a side-effect that should be avoided. - if(as<IRParam>(storeLocation)) - return false; - - // If attempting to store to a location without a clone, - // then this instruction likely has side-effects external to the - // current function. - // - if(!lookUp(&cloneEnv, storeLocation)) - return false; - } - - return true; - } - - IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP) - { - - // Clone the old instruction into the new differential function. - // - IRInst* instP = cloneInst(&cloneEnv, builder, oldInstP); - - SLANG_ASSERT(instP); - - IRInst* instD = differentiateInst(builder, instP); - - // In case it's not safe to clone the old instruction, - // remove it from the graph. - // For instance, instructions that handle control flow - // (return statements) shouldn't be replicated. - // - if (isPurelyFunctional(builder, oldInstP)) - mapDifferentialInst(instP, instD); - else + if (auto primalInst = pair.primal) { - // This inst should never have been used. - SLANG_ASSERT(instP->firstUse == nullptr); + mapPrimalInst(origInst, pair.primal); - instP->removeAndDeallocate(); - mapDifferentialInst(oldInstP, instD); + mapDifferentialInst(origInst, pair.differential); + return pair.differential; } - return instD; + getSink()->diagnose(origInst->sourceLoc, + Diagnostics::internalCompilerError, + "failed to transcibe instruction"); + return nullptr; } - IRInst* differentiateInst(IRBuilder* builder, IRInst* instP) + InstPair transcribeInst(IRBuilder* builder, IRInst* origInst) { // Handle common operations - switch (instP->getOp()) + switch (origInst->getOp()) { + case kIROp_Param: + return transcribeParam(builder, as<IRParam>(origInst)); + case kIROp_Var: - return differentiateVar(builder, as<IRVar>(instP)); + return transcribeVar(builder, as<IRVar>(origInst)); case kIROp_Load: - return differentiateLoad(builder, as<IRLoad>(instP)); + return transcribeLoad(builder, as<IRLoad>(origInst)); case kIROp_Store: - return differentiateStore(builder, as<IRStore>(instP)); + return transcribeStore(builder, as<IRStore>(origInst)); case kIROp_Return: - return differentiateReturn(builder, as<IRReturn>(instP)); + return transcribeReturn(builder, as<IRReturn>(origInst)); case kIROp_Add: case kIROp_Mul: case kIROp_Sub: case kIROp_Div: - return differentiateBinaryArith(builder, instP); + return transcribeBinaryArith(builder, origInst); case kIROp_Construct: - return differentiateConstruct(builder, instP); + return transcribeConstruct(builder, origInst); case kIROp_Call: - return differentiateCall(builder, as<IRCall>(instP)); + return transcribeCall(builder, as<IRCall>(origInst)); case kIROp_swizzle: - return differentiateSwizzle(builder, as<IRSwizzle>(instP)); + return transcribeSwizzle(builder, as<IRSwizzle>(origInst)); case kIROp_constructVectorFromScalar: - return differentiateByPassthrough(builder, instP); + return transcribeByPassthrough(builder, origInst); case kIROp_unconditionalBranch: case kIROp_conditionalBranch: - return handleControlFlow(builder, instP); + return transcribeControlFlow(builder, origInst); + + case kIROp_FloatLit: + return transcribeConst(builder, origInst); } // If none of the cases have been hit, check if the instruction is a // type. // For now we don't have logic to differentiate types that appear in blocks. - // So, we ignore them. + // So, we clone and avoid differentiating them. // - if (as<IRType>(instP)) - return nullptr; - + if (auto origType = as<IRType>(origInst)) + return InstPair(cloneInst(&cloneEnv, builder, origType), nullptr); // If we reach this statement, the instruction type is likely unhandled. - getSink()->diagnose(instP->sourceLoc, + getSink()->diagnose(origInst->sourceLoc, Diagnostics::unimplemented, "this instruction cannot be differentiated"); - return nullptr; + + return InstPair(nullptr, nullptr); } }; @@ -659,8 +1030,19 @@ struct JVPDerivativeContext IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; - // processMarkedGlobalFunctions(builder); - return processReferencedFunctions(builder); + // Process all JVPDifferentiate instructions (kIROp_JVPDifferentiate), by + // generating derivative code for the referenced function. + // + bool modified = processReferencedFunctions(builder); + + // Replaces IRDifferentialPairType with an auto-generated struct, + // IRDifferentialPairGetDifferential with 'differential' field access, + // IRDifferentialPairGetPrimal with 'primal' field access, and + // IRMakeDifferentialPair with an IRMakeStruct. + // + modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage)); + + return modified; } IRInst* lookupJVPReference(IRInst* primalFunction) @@ -683,7 +1065,7 @@ struct JVPDerivativeContext // Keep processing items until the queue is complete. while (IRInst* workItem = workQueue->pop()) - { + { for(auto child = workItem->getFirstChild(); child; child = child->getNextInst()) { // Either the child instruction has more children (func/block etc..) @@ -749,6 +1131,132 @@ struct JVPDerivativeContext return true; } + IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext* diffContext) + { + if (diffContext->isInterfaceAvailable) + { + if (auto pairType = as<IRDifferentialPairType>(type)) + { + builder->setInsertBefore(pairType); + + auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( + builder, + pairType->getValueType()); + + pairType->replaceUsesWith(diffPairStructType); + pairType->removeAndDeallocate(); + + return diffPairStructType; + } + else if (auto loweredStructType = as<IRStructType>(type)) + { + // Already lowered to struct. + return loweredStructType; + } + } + return nullptr; + } + + IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) + { + + if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) + { + auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext); + + builder->setInsertBefore(makePairInst); + + List<IRInst*> operands; + operands.add(makePairInst->getPrimalValue()); + operands.add(makePairInst->getDifferentialValue()); + + auto makeStructInst = builder->emitMakeStruct(as<IRStructType>(diffPairStructType), operands); + makePairInst->replaceUsesWith(makeStructInst); + makePairInst->removeAndDeallocate(); + + return makeStructInst; + } + + return nullptr; + } + + IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) + { + + if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) + { + lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext); + + builder->setInsertBefore(getDiffInst); + + auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); + getDiffInst->replaceUsesWith(diffFieldExtract); + getDiffInst->removeAndDeallocate(); + + return diffFieldExtract; + } + else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) + { + lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext); + + builder->setInsertBefore(getPrimalInst); + + auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + getPrimalInst->replaceUsesWith(primalFieldExtract); + getPrimalInst->removeAndDeallocate(); + + return primalFieldExtract; + } + + return nullptr; + } + + bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren, DifferentiableTypeConformanceContext* diffContext) + { + bool modified = false; + + // Create a new sub-context to scan witness tables inside workItem + // (mainly relevant if instWithChildren is a generic scope) + // + auto subContext = DifferentiableTypeConformanceContext(diffContext, instWithChildren); + (&pairBuilderStorage)->diffConformanceContext = (&subContext); + + for (auto child = instWithChildren->getFirstChild(); child; ) + { + // Make sure the builder is at the right level. + builder->setInsertInto(instWithChildren); + + auto nextChild = child->getNextInst(); + + switch (child->getOp()) + { + case kIROp_DifferentialPairType: + lowerPairType(builder, as<IRType>(child), &subContext); + break; + + case kIROp_DifferentialPairGetDifferential: + case kIROp_DifferentialPairGetPrimal: + lowerPairAccess(builder, child, &subContext); + break; + + case kIROp_MakeDifferentialPair: + lowerMakePair(builder, child, &subContext); + break; + + default: + if (child->getFirstChild()) + modified = processPairTypes(builder, child, (&subContext)) | modified; + } + + child = nextChild; + } + + // Reset the context back to the parent. + (&pairBuilderStorage)->diffConformanceContext = diffContext; + + return modified; + } + // Checks decorators to see if the function should // be differentiated (kIROp_JVPDerivativeMarkerDecoration) // @@ -823,12 +1331,13 @@ struct JVPDerivativeContext { auto jvpBlock = builder->emitBlock(); transcriberStorage.mapDifferentialInst(block, jvpBlock); + transcriberStorage.mapPrimalInst(block, jvpBlock); } // Go back over the blocks, and process the children of each block. for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock()) { - auto jvpBlock = as<IRBlock>(transcriberStorage.getDifferentialInst(block, block)); + auto jvpBlock = as<IRBlock>(transcriberStorage.lookupDiffInst(block, block)); SLANG_ASSERT(jvpBlock); emitJVPBlock(builder, block, jvpBlock); } @@ -858,7 +1367,7 @@ struct JVPDerivativeContext } IRBlock* emitJVPBlock(IRBuilder* builder, - IRBlock* primalBlock, + IRBlock* origBlock, IRBlock* jvpBlock = nullptr) { JVPTranscriber* transcriber = &(transcriberStorage); @@ -869,16 +1378,17 @@ struct JVPDerivativeContext else builder->setInsertInto(jvpBlock); - // First transcribe the parameter list. This is done separately because we - // want all the derivative parameters emitted after the primal parameters - // rather than interleaved with one another. - // - transcriber->transcribeParams(builder, primalBlock->getParams()); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + { + transcriber->transcribe(builder, param); + } - // Run through every instruction and use the transcriber to generate the appropriate + // Then, run through every instruction and use the transcriber to generate the appropriate // derivative code. // - for(auto child = primalBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) { transcriber->transcribe(builder, child); } @@ -886,9 +1396,14 @@ struct JVPDerivativeContext return jvpBlock; } - JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : module(module), sink(sink) + JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : + module(module), sink(sink), + diffConformanceContextStorage(module->getModuleInst()), + pairBuilderStorage(&diffConformanceContextStorage) { transcriberStorage.sink = sink; + transcriberStorage.diffConformanceContext = &(diffConformanceContextStorage); + transcriberStorage.pairBuilder = &(pairBuilderStorage); } protected: @@ -913,7 +1428,14 @@ struct JVPDerivativeContext // Work queue to hold a stream of instructions that need // to be checked for references to derivative functions. - IRWorkQueue workQueueStorage; + IRWorkQueue workQueueStorage; + + // Context to find and manage the witness tables for types + // implementing `IDifferentiable` + DifferentiableTypeConformanceContext diffConformanceContextStorage; + + // Builder for dealing with differential pair types. + DifferentialPairTypeBuilder pairBuilderStorage; }; @@ -923,15 +1445,15 @@ bool processJVPDerivativeMarkers( IRModule* module, DiagnosticSink* sink, IRJVPDerivativePassOptions const&) -{ - JVPDerivativeContext context(module, sink); - +{ // Simplify module to remove dead code. IRDeadCodeEliminationOptions options; options.keepExportsAlive = true; options.keepLayoutsAlive = true; eliminateDeadCode(module, options); + JVPDerivativeContext context(module, sink); + return context.processModule(); } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index afb132af7..aeb6d4ea1 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -58,6 +58,8 @@ INST(Nop, nop, 0, 0) INST(AttributedType, Attributed, 0, 0) INST(ResultType, Result, 2, 0) + INST(DifferentialPairType, DiffPair, 1, 0) + /* BindExistentialsTypeBase */ // A `BindExistentials<B, T0,w0, T1,w1, ...>` represents @@ -265,6 +267,10 @@ INST(undefined, undefined, 0, 0) // INST(DefaultConstruct, defaultConstruct, 0, 0) +INST(MakeDifferentialPair, MakeDiffPair, 2, 0) +INST(DifferentialPairGetDifferential, GetDifferential, 1, 0) +INST(DifferentialPairGetPrimal, GetPrimal, 1, 0) + INST(Specialize, specialize, 2, 0) INST(lookup_interface_method, lookup_interface_method, 2, 0) INST(GetSequentialID, GetSequentialID, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index f5d5a86ac..2e2dbed5a 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1841,6 +1841,27 @@ struct IRGetTupleElement : IRInst IRInst* getElementIndex() { return getOperand(1); } }; +// An Instruction that creates a differential pair value from a +// primal and differential. +struct IRMakeDifferentialPair : IRInst +{ + IR_LEAF_ISA(MakeDifferentialPair) + IRInst* getPrimalValue() { return getOperand(0); } + IRInst* getDifferentialValue() { return getOperand(1); } +}; + +struct IRDifferentialPairGetDifferential : IRInst +{ + IR_LEAF_ISA(DifferentialPairGetDifferential) + IRInst* getBase() { return getOperand(0); } +}; + +struct IRDifferentialPairGetPrimal : IRInst +{ + IR_LEAF_ISA(DifferentialPairGetPrimal) + IRInst* getBase() { return getOperand(0); } +}; + // Constructs an `Result<T,E>` value from an error code. struct IRMakeResultError : IRInst { @@ -2278,6 +2299,10 @@ public: IRInst* rowCount, IRInst* columnCount); + IRDifferentialPairType* getDifferentialPairType( + IRType* valueType, + IRWitnessTable* witnessTable); + IRFuncType* getFuncType( UInt paramCount, IRType* const* paramTypes, @@ -2384,6 +2409,8 @@ public: IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn); + IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitSpecializeInst( IRType* type, IRInst* genericVal, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 86586c2e8..c66f0d555 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -2741,6 +2741,17 @@ namespace Slang operands); } + IRDifferentialPairType* IRBuilder::getDifferentialPairType( + IRType* valueType, + IRWitnessTable* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPairType*)getType( + kIROp_DifferentialPairType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRFuncType* IRBuilder::getFuncType( UInt paramCount, IRType* const* paramTypes, @@ -3043,6 +3054,15 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) + { + IRInst* args[] = {primal, differential}; + auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>( + this, kIROp_MakeDifferentialPair, type, 2, args); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitSpecializeInst( IRType* type, IRInst* genericVal, diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 3faa5884c..47a724def 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1255,6 +1255,13 @@ SIMPLE_IR_TYPE(TypeKind, Kind); // SIMPLE_IR_TYPE(GenericKind, Kind) +struct IRDifferentialPairType : IRType +{ + IRType* getValueType() { return (IRType*)getOperand(0); } + + IR_LEAF_ISA(DifferentialPairType) +}; + struct IRVectorType : IRType { IRType* getElementType() { return (IRType*)getOperand(0); } diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp index 2bbf3c1c0..628a075e3 100644 --- a/source/slang/slang-stdlib.cpp +++ b/source/slang/slang-stdlib.cpp @@ -257,4 +257,21 @@ namespace Slang #endif return hlslLibraryCode; } + + String Session::getAutodiffLibraryCode() + { +#if !defined(SLANG_DISABLE_STDLIB_SOURCE) + if (autodiffLibraryCode.getLength() > 0) + return autodiffLibraryCode; + + const String path = getStdlibPath(); + + StringBuilder sb; + + #include "diff.meta.slang.h" + + autodiffLibraryCode = sb.ProduceString(); +#endif + return autodiffLibraryCode; + } } diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index b3762e471..dd53c98bf 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -178,8 +178,11 @@ void Session::init() coreLanguageScope = builtinAstBuilder->create<Scope>(); coreLanguageScope->nextSibling = baseLanguageScope; + autodiffLanguageScope = builtinAstBuilder->create<Scope>(); + autodiffLanguageScope->nextSibling = coreLanguageScope; + hlslLanguageScope = builtinAstBuilder->create<Scope>(); - hlslLanguageScope->nextSibling = coreLanguageScope; + hlslLanguageScope->nextSibling = autodiffLanguageScope; slangLanguageScope = builtinAstBuilder->create<Scope>(); slangLanguageScope->nextSibling = hlslLanguageScope; @@ -290,6 +293,7 @@ SlangResult Session::compileStdLib(slang::CompileStdLibFlags compileFlags) // TODO(JS): Could make this return a SlangResult as opposed to exception addBuiltinSource(coreLanguageScope, "core", getCoreLibraryCode()); addBuiltinSource(hlslLanguageScope, "hlsl", getHLSLLibraryCode()); + addBuiltinSource(autodiffLanguageScope, "diff", getAutodiffLibraryCode()); if (compileFlags & slang::CompileStdLibFlag::WriteDocumentation) { @@ -348,6 +352,7 @@ SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes) // Let's try loading serialized modules and adding them SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, coreLanguageScope, "core")); SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, hlslLanguageScope, "hlsl")); + SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, autodiffLanguageScope, "diff")); return SLANG_OK; } diff --git a/tests/autodiff/arithmetic-jvp.slang b/tests/autodiff/arithmetic-jvp.slang index 4e4d200e1..ddd1a4aa9 100644 --- a/tests/autodiff/arithmetic-jvp.slang +++ b/tests/autodiff/arithmetic-jvp.slang @@ -4,14 +4,17 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +typedef __DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + __differentiate_jvp float f(float x) { return x; } -float g_jvp_(float x, float dx) +dpfloat g_jvp_(dpfloat dpx) { - return 2 * dx; + return dpfloat(dpx.p(), 2 * dpx.d()); } [__custom_jvp(g_jvp_)] @@ -37,15 +40,13 @@ __differentiate_jvp float j(float x, float y) void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { { - float a = 2.0; - float b = 1.5; - float da = 1.0; - float db = 1.0; - - outputBuffer[0] = __jvp(f)(a, da); // Expect: 1 - outputBuffer[1] = __jvp(f)(a, 0.0); // Expect: 0 - outputBuffer[2] = __jvp(g)(a, da); // Expect: 2 - outputBuffer[3] = __jvp(h)(a, b, da, db); // Expect: 8 - outputBuffer[4] = __jvp(j)(a, b, da, db); // Expect: 1 + dpfloat dpa = dpfloat(2.0, 1.0); + dpfloat dpb = dpfloat(1.5, 1.0); + + outputBuffer[0] = __jvp(f)(dpa).d(); // Expect: 1 + outputBuffer[1] = __jvp(f)(dpfloat(dpa.p(), 0.0)).d(); // Expect: 0 + outputBuffer[2] = __jvp(g)(dpa).d(); // Expect: 2 + outputBuffer[3] = __jvp(h)(dpa, dpb).d(); // Expect: 8 + outputBuffer[4] = __jvp(j)(dpa, dpb).d(); // Expect: 1 } } diff --git a/tests/autodiff/generic-jvp.slang b/tests/autodiff/generic-jvp.slang new file mode 100644 index 000000000..48993c21c --- /dev/null +++ b/tests/autodiff/generic-jvp.slang @@ -0,0 +1,30 @@ +//TEST_IGNORE_FILE:(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST_IGNORE_FILE:(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef __DifferentialPair<float> dpfloat; +typedef __DifferentialPair<double> dpdouble; +typedef __DifferentialPair<float3> dpfloat3; + +__generic<T:__BuiltinArithmeticType> +__differentiate_jvp T g(T x) +{ + return x + x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + dpdouble dpb = dpdouble(1.5, 2.0); + dpfloat3 dpf3 = dpfloat3(float3(1.0, 3.0, 5.0), float3(0.5, 1.5, 2.5)); + + outputBuffer[0] = f(dpa.p()); // Expect: 1 + outputBuffer[1] = __jvp(f)(dpfloat(2.0, 0.0)).d(); // Expect: 0 + outputBuffer[2] = (float)__jvp(f)(dpb).d(); // Expect: 2 + outputBuffer[3] = __jvp(f)(dpf3).d().y; // Expect: 1.5 + } +} diff --git a/tests/autodiff/inout-parameters-jvp.slang b/tests/autodiff/inout-parameters-jvp.slang index 40e9d30ca..ba04c6b65 100644 --- a/tests/autodiff/inout-parameters-jvp.slang +++ b/tests/autodiff/inout-parameters-jvp.slang @@ -4,6 +4,7 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +typedef __DifferentialPair<float> dpfloat; __differentiate_jvp void g(float x, float y, inout float z) { @@ -30,14 +31,16 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float dy = 0.5; float dz = 2.5; - __jvp(h)(x, y, z, dx, dy, dz); + dpfloat dpz = dpfloat(z, dz); - outputBuffer[0] = dz; // Expect: 12.0 - outputBuffer[1] = z; // Expect: 1.0 + __jvp(h)(dpfloat(x, dx), dpfloat(y, dy), dpz); - __jvp(g)(x, y, z, dx, dy, dz); + outputBuffer[0] = dpz.d(); // Expect: 12.0 + outputBuffer[1] = dpz.p(); // Expect: 6.75 - outputBuffer[2] = dz; // Expect: 21.5 - outputBuffer[3] = z; // Expect: 1.0 + __jvp(g)(dpfloat(x, dx), dpfloat(y, dy), dpz); + + outputBuffer[2] = dpz.d(); // Expect: 21.5 + outputBuffer[3] = dpz.p(); // Expect: 12.5 }
\ No newline at end of file diff --git a/tests/autodiff/inout-parameters-jvp.slang.expected.txt b/tests/autodiff/inout-parameters-jvp.slang.expected.txt index c48ef7bf6..324de53ca 100644 --- a/tests/autodiff/inout-parameters-jvp.slang.expected.txt +++ b/tests/autodiff/inout-parameters-jvp.slang.expected.txt @@ -1,5 +1,5 @@ type: float 12.0 -1.0 +6.75 21.5 -1.0
\ No newline at end of file +12.5
\ No newline at end of file diff --git a/tests/autodiff/local-redecl-custom-jvp.slang b/tests/autodiff/local-redecl-custom-jvp.slang index 2bc7cd582..6241a8bf5 100644 --- a/tests/autodiff/local-redecl-custom-jvp.slang +++ b/tests/autodiff/local-redecl-custom-jvp.slang @@ -3,11 +3,16 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +typedef __DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + import test_intrinsics; -float my_pow_jvp(float x, float n, float dx, float dn) +dpfloat my_pow_jvp(dpfloat x, dpfloat n) { - return dx * n * pow(x, n-1) + dn * pow(x, n) * log(x); + return dpfloat( + pow(x.p(), n.p()), + x.d() * n.p() * pow(x.p(), n.p()-1) + n.d() * pow(x.p(), n.p()) * log(x.p())); } [__custom_jvp(my_pow_jvp)] @@ -17,12 +22,12 @@ float _pow(float, float); void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) { { - float a = 5.0; - float n = 2; - float da = 1.0; - float dn = 0; + dpfloat dpa = dpfloat(5.0, 1.0); + dpfloat dpn = dpfloat(2, 0.0); - outputBuffer[0] = __jvp(_pow)(a, n, da, dn); // Expect: 10.0 - outputBuffer[1] = __jvp(_pow)(a, n, 0.0, 1.0); // Expect: 40.23595 + outputBuffer[0] = __jvp(_pow)(dpa, dpn).d(); // Expect: 10.0 + outputBuffer[1] = __jvp(_pow)( + dpfloat(dpa.p(), 0.0), + dpfloat(dpn.p(), 1.0)).d(); // Expect: 40.23595 } } diff --git a/tests/autodiff/nested-jvp.slang b/tests/autodiff/nested-jvp.slang index 222396ec8..baebeee56 100644 --- a/tests/autodiff/nested-jvp.slang +++ b/tests/autodiff/nested-jvp.slang @@ -4,29 +4,34 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +typedef __DifferentialPair<float> dpfloat; +typedef __DifferentialPair<float3> dpfloat3; + [__custom_jvp(pow_jvp)] float pow_(float x, float n) { return pow<float>(x, n); } - [__custom_jvp(max_jvp)] float max_(float x, float y) { return max<float>(x, y); } - -float pow_jvp(float x, float n, float dx, float dn) +dpfloat pow_jvp(dpfloat x, dpfloat n) { - return dx * n * pow(x, n-1) + ((dn != 0.0) ? (dn * pow(x, n) * log(x)) : 0.0); + return dpfloat( + pow(x.p(), n.p()), + x.d() * n.p() * pow(x.p(), n.p()-1) + + ((n.d() != 0.0) ? (n.d() * pow(x.p(), n.p()) * log(x.p())) : 0.0)); } - -float max_jvp(float x, float y, float dx, float dy) +dpfloat max_jvp(dpfloat x, dpfloat y) { - return (x > y) ? dx : dy; + return dpfloat( + max(x.p(), y.p()), + (x.p() > y.p()) ? x.d() : y.d()); } @@ -53,7 +58,10 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float3 d_f90 = float3(0.9, 0.9, 0.9); float d_cosTheta = 1.0; - outputBuffer[0] = __jvp(fresnel)(f0, f90, cosTheta, d_f0, d_f90, d_cosTheta).y; // Expect: -0.031250 + outputBuffer[0] = __jvp(fresnel)( + dpfloat3(f0, d_f0), + dpfloat3(f90, d_f90), + dpfloat(cosTheta, d_cosTheta)).d().y; // Expect: -0.031250 float a = 1.0; float b = -0.4; @@ -63,8 +71,16 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float db = -1.0; float dc = 0.2; - outputBuffer[1] = __jvp(g)(a, b, c, da, db, dc); // Expect: -0.24375 - outputBuffer[2] = g(a, b, c); // Expect: 0.95625 - outputBuffer[3] = __jvp(g)(a, b, 3.0, da, db, dc); // Expect: -0.4; + outputBuffer[1] = __jvp(g)( + dpfloat(a, da), + dpfloat(b, db), + dpfloat(c, dc)).d(); // Expect: -0.24375 + + outputBuffer[2] = g(a, b, c); // Expect: 0.95625 + + outputBuffer[3] = __jvp(g)( + dpfloat(a, da), + dpfloat(b, db), + dpfloat(3.0, dc)).d(); // Expect: -0.4; } } diff --git a/tests/autodiff/out-parameters-jvp.slang b/tests/autodiff/out-parameters-jvp.slang index 58c6cfeb0..b243d4fb5 100644 --- a/tests/autodiff/out-parameters-jvp.slang +++ b/tests/autodiff/out-parameters-jvp.slang @@ -4,6 +4,8 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +typedef __DifferentialPair<float> dpfloat; + __differentiate_jvp void h(float x, float y, out float result) { float m = x + y; @@ -20,9 +22,9 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float dx = 1.0; float dy = 0.5; - float dresult = 0.0f; - __jvp(h)(x, y, dx, dy, dresult); + dpfloat dresult; + __jvp(h)(dpfloat(x, dx), dpfloat(y, dy), dresult); - outputBuffer[0] = dresult; // Expect: 9.5 + outputBuffer[0] = dresult.d(); // Expect: 9.5 }
\ No newline at end of file diff --git a/tests/autodiff/test-intrinsics-jvp.slang b/tests/autodiff/test-intrinsics-jvp.slang new file mode 100644 index 000000000..333c89189 --- /dev/null +++ b/tests/autodiff/test-intrinsics-jvp.slang @@ -0,0 +1,17 @@ +//TEST_IGNORE_FILE: + +__exported import test_intrinsics; + +[__custom_jvp(pow_jvp)] +float pow_(float x, float n); +float pow_jvp(float x, float n, float dx, float dn) +{ + return dx * n * pow(x, n-1) + dn * pow(x, n) * log(x); +} + +[__custom_jvp(max_jvp)] +float max_(float x, float y); +float max_jvp(float x, float y, float dx, float dy) +{ + return (x > y) ? dx : dy; +}
\ No newline at end of file diff --git a/tests/autodiff/vector-arithmetic-jvp.slang b/tests/autodiff/vector-arithmetic-jvp.slang index 2b43f1752..393cc18ec 100644 --- a/tests/autodiff/vector-arithmetic-jvp.slang +++ b/tests/autodiff/vector-arithmetic-jvp.slang @@ -4,6 +4,10 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +typedef __DifferentialPair<float2> dpfloat2; +typedef __DifferentialPair<float3> dpfloat3; +typedef __DifferentialPair<float4> dpfloat4; + __differentiate_jvp float3 f(float3 x) { return x; @@ -37,6 +41,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float3 a = float3(2.0, 2.0, 2.0); float3 b = float3(1.5, 1.5, 1.5); float3 da = float3(1.0, 1.0, 1.0); + //dpfloat3 dpa = dpfloat3(a, da); float2 a2 = float2(2.0, 1.0); float2 b2 = float2(1.5, -2.0); @@ -44,9 +49,18 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) float4 a4 = float4(2.0, 1.0, 0.0, 2.0); float4 b4 = float4(1.5, -2.0, 1.0, 1.5); - outputBuffer[0] = __jvp(f)(a, da).z; // Expect: 1 - outputBuffer[1] = __jvp(g)(a, b, da, float3(2.0, 1.0, 0.0)).y; // Expect: 8 - outputBuffer[2] = __jvp(h)(a2, b2, float2(1.0, 0.0), float2(1.0, 1.0)).x; // Expect: 8 - outputBuffer[3] = __jvp(j)(a4, b4, float4(1.0), float4(2.0)).w; // Expect: 9 + outputBuffer[0] = __jvp(f)(dpfloat3(a, da)).d().z; // Expect: 1 + + outputBuffer[1] = __jvp(g)( + dpfloat3(a, da), + dpfloat3(b, float3(2.0, 1.0, 0.0))).d().y; // Expect: 8 + + outputBuffer[2] = __jvp(h)( + dpfloat2(a2, float2(1.0, 0.0)), + dpfloat2(b2, float2(1.0, 1.0))).d().x; // Expect: 8 + + outputBuffer[3] = __jvp(j)( + dpfloat4(a4, float4(1.0)), + dpfloat4(b4, float4(2.0))).d().w; // Expect: 9 } } diff --git a/tests/autodiff/vector-swizzle-jvp.slang b/tests/autodiff/vector-swizzle-jvp.slang index 6722b54dc..775c0140e 100644 --- a/tests/autodiff/vector-swizzle-jvp.slang +++ b/tests/autodiff/vector-swizzle-jvp.slang @@ -4,6 +4,10 @@ //TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer RWStructuredBuffer<float> outputBuffer; +typedef __DifferentialPair<float2> dpfloat2; +typedef __DifferentialPair<float3> dpfloat3; +typedef __DifferentialPair<float4> dpfloat4; + __differentiate_jvp float2 f(float3 x) { return x.zy; @@ -23,16 +27,16 @@ void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) float3 a = float3(2.0, 2.0, 2.0); float3 da = float3(1.0, 0.5, 1.0); - outputBuffer[0] = __jvp(f)(a, da).x; // Expect: 1 - outputBuffer[1] = __jvp(f)(a, da).y; // Expect: 0.5 + outputBuffer[0] = __jvp(f)(dpfloat3(a, da)).d().x; // Expect: 1 + outputBuffer[1] = __jvp(f)(dpfloat3(a, da)).d().y; // Expect: 0.5 float3 x = float3(0.5, 2.0, 0.5); float4 y = float4(-1.5, 1.0, 4.0, 2.0); float3 dx = float3(1.0, 0.0, -1.0); float4 dy = float4(0.0, 0.5, -0.25, 1.0); - outputBuffer[2] = __jvp(g)(x, y, dx, dy).x; // Expect: -2.25 - outputBuffer[3] = __jvp(g)(x, y, dx, dy).y; // Expect: 0.5 + outputBuffer[2] = __jvp(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().x; // Expect: -2.25 + outputBuffer[3] = __jvp(g)(dpfloat3(x, dx), dpfloat4(y, dy)).d().y; // Expect: 0.5 } } |
