summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-08-05 13:19:20 -0400
committerGitHub <noreply@github.com>2022-08-05 13:19:20 -0400
commit2db8c15c04f2aade49636e42f0adee636afb3b73 (patch)
tree774758a9f854ddf655f6c46765a3ef8ca1950857 /source
parent12a846e8facf090aaeb68fcabf55867f5eaed747 (diff)
Added a new differential type system and various improvements (#2343)
* Merge slang-ir-diff-jvp.cpp * Added support and tests for other float vector types * Added swizzle test and code to handle it (tests failing currently) * Fixed one test, the other is still pending * Fixed instruction cloning logic to avoid modifying original function * Fixed an issue with custom 'pow_jvp' and added support for vector contructor * Minor update to comments * Fixed support for division * Fixed an issue with uninitialized diagnostic sink * Moved derivative processing to after mandatory inlining. Skip instructions that don't have side-effects and aren't used by anything. * WIP: Handling unconditional control flow and multi-block functions * Support for unconditional multi-block functions * Added a dead code elimination step to the derivative pass * Changed name of 'hasNoSideEffects()' * Refactored variable names * Added initial IR defs for new type system * Added necessary logic for semantic checking * Overhauled type system to use builtin pair types and conform to the IDifferentiable interface * Automatically replace IRDifferentiablePairType to a custom IRStructType * Added generics handling by expanding the conformance context functionality and allowing for type parameters * Minor fix: early return in processPairTypes() * Minor fixes to differentiable resolution on generic types * Added new instructions for differential pairs. Basic tests work now. Looking into generic types. * Adjusted most tests to the new type system. OutType and InOutType are still not properly working. * Updated __jvp to produce both primal and differential output * Moved autodiff related declarations to diff.meta.slang * Refactored variable names * Added initial IR defs for new type system * Added necessary logic for semantic checking * Overhauled type system to use builtin pair types and conform to the IDifferentiable interface * Automatically replace IRDifferentiablePairType to a custom IRStructType * Added generics handling by expanding the conformance context functionality and allowing for type parameters * Minor fix: early return in processPairTypes() * Minor fixes to differentiable resolution on generic types * Added new instructions for differential pairs. Basic tests work now. Looking into generic types. * Adjusted most tests to the new type system. OutType and InOutType are still not properly working. * Updated __jvp to produce both primal and differential output * Moved autodiff related declarations to diff.meta.slang * Removed external changes * Cleanup the transcription logic: each case returns a pair of insts for the primal and differential computation.
Diffstat (limited to 'source')
-rw-r--r--source/slang/core.meta.slang9
-rw-r--r--source/slang/diff.meta.slang70
-rw-r--r--source/slang/slang-ast-builder.cpp26
-rw-r--r--source/slang/slang-ast-builder.h4
-rw-r--r--source/slang/slang-ast-type.h15
-rw-r--r--source/slang/slang-check-expr.cpp66
-rw-r--r--source/slang/slang-check-impl.h9
-rwxr-xr-xsource/slang/slang-compiler.h3
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp1184
-rw-r--r--source/slang/slang-ir-inst-defs.h6
-rw-r--r--source/slang/slang-ir-insts.h27
-rw-r--r--source/slang/slang-ir.cpp20
-rw-r--r--source/slang/slang-ir.h7
-rw-r--r--source/slang/slang-stdlib.cpp17
-rw-r--r--source/slang/slang.cpp7
15 files changed, 1089 insertions, 381 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 862595b90..41f066486 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -71,10 +71,6 @@ syntax snorm : SNormModifier;
///
syntax __extern_cpp : ExternCppModifier;
-/// Modifer to mark a function for forward-mode differentiation.
-/// i.e. the compiler will automatically generate a new function
-/// that computes the jacobian-vector product of the original.
-syntax __differentiate_jvp : JVPDerivativeModifier;
/// A type that can be used as an operand for builtins
[sealed]
@@ -697,7 +693,6 @@ ${{{{
}
}}}}
-
//@ public:
/// Sampling state for filtered texture fetches.
@@ -2321,7 +2316,3 @@ attribute_syntax [noinline] : NoInlineAttribute;
__attributeTarget(StructDecl)
attribute_syntax [payload] : PayloadAttribute;
-
-// Custom JVP Function reference
-__attributeTarget(FuncDecl)
-attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
new file mode 100644
index 000000000..e604140ae
--- /dev/null
+++ b/source/slang/diff.meta.slang
@@ -0,0 +1,70 @@
+
+/// Modifer to mark a function for forward-mode differentiation.
+/// i.e. the compiler will automatically generate a new function
+/// that computes the jacobian-vector product of the original.
+syntax __differentiate_jvp : JVPDerivativeModifier;
+
+// Custom JVP Function reference
+__attributeTarget(FuncDecl)
+attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
+
+//@ public:
+
+ /// Interface to denote types as differentiable.
+ /// Allows for user-specified differential types as
+ /// well as automatic generation, for when the associated type
+ /// hasn't been declared explicitly.
+__magic_type(DifferentiableType)
+interface IDifferentiable
+{
+ associatedtype Differential;
+};
+
+ /// Pair type that serves to wrap the primal and
+ /// differential types of an arbitrary type T.
+__generic<T : IDifferentiable>
+__magic_type(DifferentialPairType)
+__intrinsic_type($(kIROp_DifferentialPairType))
+struct __DifferentialPair
+{
+
+ __intrinsic_op($(kIROp_MakeDifferentialPair))
+ __init(T _primal, T.Differential _differential);
+
+ __intrinsic_op($(kIROp_DifferentialPairGetDifferential))
+ T.Differential d();
+
+ T.Differential getDifferential()
+ {
+ return d();
+ }
+
+ __intrinsic_op($(kIROp_DifferentialPairGetPrimal))
+ T p();
+
+ T getPrimal()
+ {
+ return p();
+ }
+};
+
+// Add extensions for the standard types
+extension float : IDifferentiable
+{
+ typedef float Differential;
+}
+
+extension vector<float, 3> : IDifferentiable
+{
+ typedef vector<float, 3> Differential;
+}
+
+extension vector<float, 2> : IDifferentiable
+{
+ typedef vector<float, 2> Differential;
+}
+
+extension vector<float, 4> : IDifferentiable
+{
+ typedef vector<float, 4> Differential;
+}
diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp
index 7f23f5492..868763f76 100644
--- a/source/slang/slang-ast-builder.cpp
+++ b/source/slang/slang-ast-builder.cpp
@@ -265,7 +265,7 @@ VectorExpressionType* ASTBuilder::getVectorType(
IntVal* elementCount)
{
auto vectorGenericDecl = as<GenericDecl>(m_sharedASTBuilder->findMagicDecl("Vector"));
-
+
auto vectorTypeDecl = vectorGenericDecl->inner;
auto substitutions = create<GenericSubstitution>();
@@ -278,6 +278,30 @@ VectorExpressionType* ASTBuilder::getVectorType(
return as<VectorExpressionType>(DeclRefType::create(this, declRef));
}
+DifferentialPairType* ASTBuilder::getDifferentialPairType(Type* valueType, Witness* conformanceWitness)
+{
+ auto genericDecl = dynamicCast<GenericDecl>(m_sharedASTBuilder->findMagicDecl("DifferentialPairType"));
+
+ auto typeDecl = genericDecl->inner;
+
+ auto substitutions = create<GenericSubstitution>();
+ substitutions->genericDecl = genericDecl;
+ substitutions->args.add(valueType);
+ substitutions->args.add(conformanceWitness);
+
+ auto declRef = DeclRef<Decl>(typeDecl, substitutions);
+ auto rsType = DeclRefType::create(this, declRef);
+
+ return as<DifferentialPairType>(rsType);
+}
+
+DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterface()
+{
+ DeclRef<InterfaceDecl> declRef;
+ declRef.decl = dynamicCast<InterfaceDecl>(m_sharedASTBuilder->findMagicDecl("DifferentiableType"));
+ return declRef;
+}
+
DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs)
{
DeclRef<Decl> declRef;
diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h
index d8a8679ee..3c0303e70 100644
--- a/source/slang/slang-ast-builder.h
+++ b/source/slang/slang-ast-builder.h
@@ -164,6 +164,10 @@ public:
VectorExpressionType* getVectorType(Type* elementType, IntVal* elementCount);
+ DifferentialPairType* getDifferentialPairType(Type* valueType, Witness* conformanceWitness);
+
+ DeclRef<InterfaceDecl> getDifferentiableInterface();
+
DeclRef<Decl> getBuiltinDeclRef(const char* builtinMagicTypeName, ConstArrayView<Val*> genericArgs);
Type* getAndType(Type* left, Type* right);
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index b82c6b182..5d4e42bfb 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -437,6 +437,21 @@ protected:
};
+// A differential pair type, e.g., `__DifferentialPair<T>`
+class DifferentialPairType : public ArithmeticExpressionType
+{
+ SLANG_AST_CLASS(DifferentialPairType)
+
+ // The type of vector elements.
+ // As an invariant, this should be a basic type or an alias.
+ Type* baseType = nullptr;
+};
+
+class DifferentiableType : public BuiltinType
+{
+ SLANG_AST_CLASS(DifferentiableType)
+};
+
// A vector type, e.g., `vector<T,N>`
class VectorExpressionType : public ArithmeticExpressionType
{
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b7f99c4e7..a787af211 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1526,48 +1526,42 @@ namespace Slang
return expr;
}
- // This function proceses primal params (i.e params of the inner function that is being
- // differentiated) that need to be carried over to the function signature for the JVP
- // function. (eg. out types can be discarded)
- //
- Type* primalToInputType(ASTBuilder*, Type* primalType)
- {
- if (auto primalOutType = as<OutType>(primalType))
- return nullptr;
- else if (auto primalInOutType = as<InOutType>(primalType))
- return primalInOutType->getValueType();
-
- return primalType;
- }
- Type* primalToJVPParamType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
{
- // Only float and vector<float> types can be differentiated for now.
-
- if (primalType->equals(builder->getFloatType()))
- return primalType;
- else if (auto primalVectorType = as<VectorExpressionType>(primalType))
- {
- if (auto jvpElementType = primalToJVPParamType(builder, primalVectorType->elementType))
- return builder->getVectorType(jvpElementType, primalVectorType->elementCount);
- }
- else if (auto primalOutType = as<OutType>(primalType))
+ // Check for type modifiers like 'out' and 'inout'. We need to differentiate the
+ // nested type.
+ //
+ if (auto primalOutType = as<OutType>(primalType))
{
- return builder->getOutType(primalToJVPParamType(builder, primalOutType->getValueType()));
+ return builder->getOutType(_toDifferentialParamType(builder, primalOutType->getValueType()));
}
else if (auto primalInOutType = as<InOutType>(primalType))
{
- return builder->getInOutType(primalToJVPParamType(builder, primalInOutType->getValueType()));
+ return builder->getInOutType(_toDifferentialParamType(builder, primalInOutType->getValueType()));
}
- return nullptr;
+
+ // Get a reference to the builtin 'IDifferentiable' interface
+ auto differentiableInterface = builder->getDifferentiableInterface();
+
+ // Check if the provided type inherits from IDifferentiable.
+ // If not, return the original type.
+ if (auto conformanceWitness = as<Witness>(tryGetInterfaceConformanceWitness(primalType, differentiableInterface)))
+ return builder->getDifferentialPairType(primalType, conformanceWitness);
+ else
+ return primalType;
+
}
- Type* primalToJVPReturnType(ASTBuilder* builder, Type* primalType)
+ Type* SemanticsVisitor::_toJVPReturnType(ASTBuilder* builder, Type* primalType)
{
- if(auto jvpType = primalToJVPParamType(builder, primalType))
- return jvpType;
+ if (auto conformanceWitness =
+ as<Witness>(tryGetInterfaceConformanceWitness(
+ primalType,
+ builder->getDifferentiableInterface())))
+ return builder->getDifferentialPairType(primalType, conformanceWitness);
else
- return builder->getVoidType();
+ return primalType;
}
Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
@@ -1588,7 +1582,7 @@ namespace Slang
// The JVP return type is float if primal return type is float
// void otherwise.
//
- jvpType->resultType = primalToJVPReturnType(astBuilder, primalType->getResultType());
+ jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType());
// No support for differentiating function that throw errors, for now.
SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
@@ -1596,13 +1590,7 @@ namespace Slang
for (UInt i = 0; i < primalType->getParamCount(); i++)
{
- if(auto primalInputType = primalToInputType(astBuilder, primalType->getParamType(i)))
- jvpType->paramTypes.add(primalInputType);
- }
-
- for (UInt i = 0; i < primalType->getParamCount(); i++)
- {
- if(auto jvpParamType = primalToJVPParamType(astBuilder, primalType->getParamType(i)))
+ if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i)))
jvpType->paramTypes.add(jvpParamType);
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ebc6e05d5..f2f7a6bd1 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -673,6 +673,15 @@ namespace Slang
void _validateCircularVarDefinition(VarDeclBase* varDecl);
bool shouldSkipChecking(Decl* decl, DeclCheckState state);
+
+ // Auto-diff convenience functions for translating primal types to differential types.
+ Type* _toDifferentialParamType(ASTBuilder* builder, Type* primalType);
+
+ // Translate a return type to the return type of a forward-mode differentiated
+ // function.
+ //
+ Type* _toJVPReturnType(ASTBuilder* builder, Type* primalType);
+
public:
bool ValuesAreEqual(
diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h
index 5ba121742..fd338361e 100755
--- a/source/slang/slang-compiler.h
+++ b/source/slang/slang-compiler.h
@@ -2950,6 +2950,7 @@ namespace Slang
Scope* coreLanguageScope = nullptr;
Scope* hlslLanguageScope = nullptr;
Scope* slangLanguageScope = nullptr;
+ Scope* autodiffLanguageScope = nullptr;
ModuleDecl* baseModuleDecl = nullptr;
List<RefPtr<Module>> stdlibModules;
@@ -2981,10 +2982,12 @@ namespace Slang
String slangLibraryCode;
String hlslLibraryCode;
String glslLibraryCode;
+ String autodiffLibraryCode;
String getStdlibPath();
String getCoreLibraryCode();
String getHLSLLibraryCode();
+ String getAutodiffLibraryCode();
RefPtr<SharedASTBuilder> m_sharedASTBuilder;
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index fd1d0086d..554a407ee 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -9,19 +9,349 @@
namespace Slang
{
+template<typename P, typename D>
+struct Pair
+{
+ P primal;
+ D differential;
+
+ Pair(P primal, D differential) : primal(primal), differential(differential)
+ {}
+};
+
+typedef Pair<IRInst*, IRInst*> InstPair;
+
+struct DifferentiableTypeConformanceContext
+{
+ Dictionary<IRInst*, IRInst*> witnessTableMap;
+
+ IRInst* inst = nullptr;
+
+ // A reference to the builtin IDifferentiable interface type.
+ // We use this to look up all the other types (and type exprs)
+ // that conform to a base type.
+ //
+ IRInterfaceType* differentiableInterfaceType = nullptr;
+
+ // The struct key for the 'Differential' associated type
+ // defined inside IDifferential. We use this to lookup the differential
+ // type in the conformance table associated with the concrete type.
+ //
+ IRStructKey* differentialAssocTypeStructKey = nullptr;
+
+ // Modules that don't use differentiable types
+ // won't have the IDifferentiable interface type available.
+ // Set to false to indicate that we are uninitialized.
+ //
+ bool isInterfaceAvailable = false;
+
+ // For handling generic blocks, we use a parent pointer to allow
+ // looking up types in all relevant scopes.
+ DifferentiableTypeConformanceContext* parent = nullptr;
+
+ DifferentiableTypeConformanceContext(DifferentiableTypeConformanceContext* parent, IRInst* inst) : parent(parent), inst(inst)
+ {
+ if (parent)
+ {
+ differentiableInterfaceType = parent->differentiableInterfaceType;
+ differentialAssocTypeStructKey = parent->differentialAssocTypeStructKey;
+ isInterfaceAvailable = parent->isInterfaceAvailable;
+ }
+ else
+ {
+ differentiableInterfaceType = as<IRInterfaceType>(findDifferentiableInterface());
+ if (differentiableInterfaceType)
+ {
+ differentialAssocTypeStructKey = findDifferentialTypeStructKey();
+
+ if (differentialAssocTypeStructKey)
+ isInterfaceAvailable = true;
+ }
+ }
+
+ if (isInterfaceAvailable)
+ {
+ // Load all witness tables corresponding to the IDifferentiable interface.
+ loadWitnessTablesForInterface(differentiableInterfaceType);
+ }
+ }
+
+ DifferentiableTypeConformanceContext(IRInst* inst) :
+ DifferentiableTypeConformanceContext(nullptr, inst)
+ {}
+
+ // Lookup a witness table for the concreteType. One should exist if concreteType
+ // inherits (successfully) from IDifferentiable.
+ //
+ IRInst* lookUpConformanceForType(IRInst* type)
+ {
+ SLANG_ASSERT(isInterfaceAvailable);
+
+ if (witnessTableMap.ContainsKey(type))
+ return witnessTableMap[type];
+ else if (parent)
+ return parent->lookUpConformanceForType(type);
+ else
+ return nullptr;
+ }
+
+ // Lookup and return the 'Differential' type declared in the concrete type
+ // in order to conform to the IDifferentiable interface.
+ // Note that inside a generic block, this will be a witness table lookup instruction
+ // that gets resolved during the specialization pass.
+ //
+ IRInst* getDifferentialForType(IRBuilder* builder, IRType* origType)
+ {
+ SLANG_ASSERT(isInterfaceAvailable);
+
+ if (auto conformance = lookUpConformanceForType(origType))
+ {
+ if (auto witnessTable = as<IRWitnessTable>(conformance))
+ {
+ for (auto entry : witnessTable->getEntries())
+ {
+ if (entry->getRequirementKey() == differentialAssocTypeStructKey)
+ return as<IRType>(entry->getSatisfyingVal());
+ }
+ }
+ else if (auto witnessTableParam = as<IRParam>(conformance))
+ {
+ return builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ witnessTableParam,
+ differentialAssocTypeStructKey);
+ }
+ }
+
+ return nullptr;
+ }
+
+ private:
+
+ IRInst* findDifferentiableInterface()
+ {
+ if (auto module = as<IRModuleInst>(inst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ // TODO: This seems like a particularly dangerous way to look for an interface.
+ // See if we can lower IDifferentiable to a separate IR inst.
+ //
+ if (globalInst->getOp() == kIROp_InterfaceType &&
+ as<IRInterfaceType>(globalInst)->findDecoration<IRNameHintDecoration>()->getName() == "IDifferentiable")
+ {
+ return globalInst;
+ }
+ }
+ }
+ return nullptr;
+ }
+
+ IRStructKey* findDifferentialTypeStructKey()
+ {
+ if (as<IRModuleInst>(inst) && differentiableInterfaceType)
+ {
+ // Assume for now that IDifferentiable has exactly one field: the 'Differential' associated type.
+ SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 1);
+ if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(0)))
+ return as<IRStructKey>(entry->getRequirementKey());
+ else
+ {
+ SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type");
+ }
+ }
+
+ return nullptr;
+ }
+
+ void loadWitnessTablesForInterface(IRInst* interfaceType)
+ {
+
+ if (auto module = as<IRModuleInst>(inst))
+ {
+ for (auto globalInst : module->getGlobalInsts())
+ {
+ if (globalInst->getOp() == kIROp_WitnessTable &&
+ cast<IRWitnessTableType>(globalInst->getDataType())->getConformanceType() ==
+ interfaceType)
+ {
+ // TODO: Can we have multiple conformances for the same pair of types?
+ // TODO: Can type instrs be duplicated (i.e. two different float types)? And if they are duplicated, can
+ // we supply the dictionary with a custom equality rule that uses 'type1->equals(type2)'
+ witnessTableMap.Add(as<IRWitnessTable>(globalInst)->getConcreteType(), globalInst);
+ }
+ }
+ }
+ else if (auto generic = as<IRGeneric>(inst))
+ {
+ List<IRParam*> typeParams;
+
+ auto genericParam = generic->getFirstParam();
+ while (genericParam)
+ {
+ if (as<IRTypeType>(genericParam->getDataType()))
+ {
+ typeParams.add(genericParam);
+ }
+ else
+ break;
+
+ genericParam = genericParam->getNextParam();
+ }
+
+ UCount tableIndex = 0;
+ while (genericParam)
+ {
+ SLANG_ASSERT(!as<IRTypeType>(genericParam->getDataType()));
+ if (auto witnessTableType = as<IRWitnessTableType>(genericParam->getDataType()))
+ {
+ if (witnessTableType->getConformanceType() == differentiableInterfaceType)
+ witnessTableMap.Add(typeParams[tableIndex], genericParam);
+ }
+ else
+ break;
+
+ tableIndex += 1;
+ genericParam = genericParam->getNextParam();
+ }
+
+ }
+
+ }
+
+};
+
+struct DifferentialPairTypeBuilder
+{
+
+ DifferentialPairTypeBuilder(DifferentiableTypeConformanceContext* diffConformanceContext) :
+ diffConformanceContext(diffConformanceContext)
+ {}
+
+ IRInst* emitPrimalFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ {
+ if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
+ {
+ auto primalField = as<IRStructField>(basePairStructType->getFirstChild());
+ SLANG_ASSERT(primalField);
+
+ return as<IRFieldExtract>(builder->emitFieldExtract(
+ primalField->getFieldType(),
+ baseInst,
+ primalField->getKey()
+ ));
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ {
+ if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
+ {
+ auto primalField = as<IRStructField>(pairStructType->getFirstChild());
+ SLANG_ASSERT(primalField);
+
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType(primalField->getFieldType()),
+ baseInst,
+ primalField->getKey()
+ ));
+ }
+ }
+ else
+ {
+ SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
+ }
+ return nullptr;
+ }
+
+ IRInst* emitDiffFieldAccess(IRBuilder* builder, IRInst* baseInst)
+ {
+ if (auto basePairStructType = as<IRStructType>(baseInst->getDataType()))
+ {
+ auto diffField = as<IRStructField>(basePairStructType->getFirstChild()->getNextInst());
+ SLANG_ASSERT(diffField);
+
+ return as<IRFieldExtract>(builder->emitFieldExtract(
+ diffField->getFieldType(),
+ baseInst,
+ diffField->getKey()
+ ));
+ }
+ else if (auto ptrType = as<IRPtrTypeBase>(baseInst->getDataType()))
+ {
+ if (auto pairStructType = as<IRStructType>(ptrType->getValueType()))
+ {
+ auto diffField = as<IRStructField>(pairStructType->getFirstChild()->getNextInst());
+ SLANG_ASSERT(diffField);
+
+ return as<IRFieldAddress>(builder->emitFieldAddress(
+ builder->getPtrType(diffField->getFieldType()),
+ baseInst,
+ diffField->getKey()
+ ));
+ }
+ }
+ else
+ {
+ SLANG_UNREACHABLE("basePairType must be an IRStructType or PtrType<IRStructType>");
+ }
+ return nullptr;
+ }
+
+ IRStructType* _createDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ {
+ if (auto diffBaseType = diffConformanceContext->getDifferentialForType(builder, origBaseType))
+ {
+ auto diffPairType = builder->createStructType();
+
+ // Create a keys for the primal and differential fields.
+ IRStructKey* origKey = builder->createStructKey();
+ builder->addNameHintDecoration(origKey, UnownedTerminatedStringSlice("primal"));
+ builder->createStructField(diffPairType, origKey, origBaseType);
+
+ IRStructKey* diffKey = builder->createStructKey();
+ builder->addNameHintDecoration(diffKey, UnownedTerminatedStringSlice("differential"));
+ builder->createStructField(diffPairType, diffKey, (IRType*)(diffBaseType));
+
+ return diffPairType;
+ }
+ return nullptr;
+ }
+
+ IRStructType* getOrCreateDiffPairType(IRBuilder* builder, IRType* origBaseType)
+ {
+ if (pairTypeCache.ContainsKey(origBaseType))
+ return pairTypeCache[origBaseType];
+
+ auto pairType = _createDiffPairType(builder, origBaseType);
+ pairTypeCache.Add(origBaseType, pairType);
+
+ return pairType;
+ }
+
+ Dictionary<IRType*, IRStructType*> pairTypeCache;
+
+ DifferentiableTypeConformanceContext* diffConformanceContext;
+
+};
+
struct JVPTranscriber
{
// Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
// their differential values.
- Dictionary<IRInst*, IRInst*> instMapD;
+ Dictionary<IRInst*, IRInst*> instMapD;
// Cloning environment to hold mapping from old to new copies for the primal
// instructions.
- IRCloneEnv cloneEnv;
+ IRCloneEnv cloneEnv;
// Diagnostic sink for error messages.
- DiagnosticSink* sink;
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ DifferentiableTypeConformanceContext* diffConformanceContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
+ DifferentialPairTypeBuilder* pairBuilder;
DiagnosticSink* getSink()
{
@@ -29,284 +359,318 @@ struct JVPTranscriber
return sink;
}
- void mapDifferentialInst(IRInst* instP, IRInst* instD)
+ void mapDifferentialInst(IRInst* origInst, IRInst* diffInst)
{
- instMapD.Add(instP, instD);
+ instMapD.Add(origInst, diffInst);
}
- IRInst* getDifferentialInst(IRInst* instP)
+ void mapPrimalInst(IRInst* origInst, IRInst* primalInst)
{
- return instMapD[instP];
+ if (cloneEnv.mapOldValToNew.ContainsKey(origInst) && cloneEnv.mapOldValToNew[origInst] != primalInst)
+ {
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "inconsistent primal instruction for original");
+ }
+ else
+ {
+ cloneEnv.mapOldValToNew[origInst] = primalInst;
+ }
}
- IRInst* getDifferentialInst(IRInst* instP, IRInst* defaultInst)
+ IRInst* lookupDiffInst(IRInst* origInst)
{
- return (hasDifferentialInst(instP)) ? instMapD[instP] : defaultInst;
+ return instMapD[origInst];
}
- bool hasDifferentialInst(IRInst* instP)
+ IRInst* lookupDiffInst(IRInst* origInst, IRInst* defaultInst)
{
- return instMapD.ContainsKey(instP);
+ return (hasDifferentialInst(origInst)) ? instMapD[origInst] : defaultInst;
}
- IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ bool hasDifferentialInst(IRInst* origInst)
{
- List<IRType*> parameterTypesD;
- IRType* returnTypeD;
+ return instMapD.ContainsKey(origInst);
+ }
- // Add all primal parameters to the list.
- for (UIndex i = 0; i < funcType->getParamCount(); i++)
- {
- // TODO(sai): Move this check to a separate function.
- if (!as<IROutType>(funcType->getParamType(i)))
- parameterTypesD.add(funcType->getParamType(i));
+ IRInst* lookupPrimalInst(IRInst* origInst)
+ {
+ return cloneEnv.mapOldValToNew[origInst];
+ }
+
+ IRInst* lookupPrimalInst(IRInst* origInst, IRInst* defaultInst)
+ {
+ return (hasPrimalInst(origInst)) ? lookupPrimalInst(origInst) : defaultInst;
+ }
+
+ bool hasPrimalInst(IRInst* origInst)
+ {
+ return cloneEnv.mapOldValToNew.ContainsKey(origInst);
+ }
+
+ IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst)
+ {
+ if (!hasDifferentialInst(origInst))
+ {
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasDifferentialInst(origInst));
+ }
+
+ return lookupDiffInst(origInst);
+ }
+
+ IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst)
+ {
+ if (!hasPrimalInst(origInst))
+ {
+ transcribe(builder, origInst);
+ SLANG_ASSERT(hasPrimalInst(origInst));
}
- // Add differential versions for the types we support.
+ return lookupPrimalInst(origInst);
+ }
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ {
+ List<IRType*> newParameterTypes;
+ IRType* diffReturnType;
+
for (UIndex i = 0; i < funcType->getParamCount(); i++)
- {
- if (auto typeD = differentiateType(builder, funcType->getParamType(i)))
- parameterTypesD.add(typeD);
+ {
+ auto origType = funcType->getParamType(i);
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ newParameterTypes.add(diffPairType);
+ else
+ newParameterTypes.add(origType);
}
- // Transcribe return type.
+ // Transcribe return type to a pair.
// This will be void if the primal return type is non-differentiable.
//
- returnTypeD = differentiateType(builder, funcType->getResultType());
- if (!returnTypeD)
- returnTypeD = builder->getVoidType();
+ if (auto returnPairType = tryGetDiffPairType(builder, funcType->getResultType()))
+ diffReturnType = returnPairType;
+ else
+ diffReturnType = builder->getVoidType();
- return builder->getFuncType(parameterTypesD, returnTypeD);
+ return builder->getFuncType(newParameterTypes, diffReturnType);
}
- IRType* differentiateType(IRBuilder* builder, IRType* typeP)
+ IRType* differentiateType(IRBuilder* builder, IRType* origType)
{
- switch (typeP->getOp())
+ switch (origType->getOp())
{
case kIROp_HalfType:
case kIROp_FloatType:
case kIROp_DoubleType:
- return builder->getType(typeP->getOp());
case kIROp_VectorType:
- // TODO(sai): Call differentiateType() on typeP.
- return as<IRVectorType>(typeP);
+ return (IRType*)(diffConformanceContext->getDifferentialForType(builder, origType));
case kIROp_OutType:
- return builder->getOutType(differentiateType(builder, as<IROutType>(typeP)->getValueType()));
+ return builder->getOutType(differentiateType(builder, as<IROutType>(origType)->getValueType()));
case kIROp_InOutType:
- return builder->getInOutType(differentiateType(builder, as<IRInOutType>(typeP)->getValueType()));
+ return builder->getInOutType(differentiateType(builder, as<IRInOutType>(origType)->getValueType()));
default:
return nullptr;
}
}
- IRInst* differentiateParam(IRBuilder* builder, IRParam* paramP)
- {
- if (IRType* typeD = differentiateType(builder, paramP->getFullType()))
- {
- IRParam* paramD = builder->emitParam(typeD);
-
- auto nameHintD = getJVPVarName(paramP);
- if (nameHintD.getLength() > 0)
- builder->addNameHintDecoration(paramD, nameHintD.getUnownedSlice());
-
- SLANG_ASSERT(paramD);
- return paramD;
- }
- return nullptr;
- }
-
- IRInst* emitInputParam(IRBuilder* builder, IRParam* paramP)
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* origType)
{
- // Convert primal 'inout' types into pure input types, because a
- // JVP transformed function must never have primal side-effects.
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
//
- if (auto inoutTypeP = as<IRInOutType>(paramP->getDataType()))
+ if (auto origPtrType = as<IRPtrTypeBase>(origType))
{
- auto newParamP = builder->emitParam(inoutTypeP->getValueType());
- cloneEnv.mapOldValToNew.Add(paramP, newParamP);
- cloneInstDecorationsAndChildren(&cloneEnv, builder->getSharedBuilder(), paramP, newParamP);
-
- return newParamP;
- }
- else if (as<IROutType>(paramP->getDataType()))
- {
- getSink()->diagnose(paramP->sourceLoc,
- Diagnostics::unexpected,
- "encountered unexpected output parameter");
- return nullptr;
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(origType->getOp(), diffPairValueType);
+ else
+ return nullptr;
}
- else
- return as<IRParam>(cloneInst(&cloneEnv, builder, paramP));
+
+ return pairBuilder->getOrCreateDiffPairType(builder, origType);
}
- List<IRParam*> transcribeParams(IRBuilder* builder, IRInstList<IRParam> paramListP)
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
{
- // Clone (and emit) all the primal parameters.
- List<IRParam*> newParamListP;
- for (auto paramP : paramListP)
+ if (auto diffPairType = tryGetDiffPairType(builder, origParam->getFullType()))
{
- if(isPurelyFunctional(builder, paramP))
- newParamListP.add(as<IRParam>(emitInputParam(builder, paramP)));
+ IRParam* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ return InstPair(
+ pairBuilder->emitPrimalFieldAccess(builder, diffPairParam),
+ pairBuilder->emitDiffFieldAccess(builder, diffPairParam));
}
+
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
- // Now emit differentials.
- List<IRParam*> newParamListD;
- for (auto paramP : paramListP)
+ // Returns "d<var-name>" to use as a name hint for variables and parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String getJVPVarName(IRInst* origVar)
+ {
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
{
- IRParam* paramD = as<IRParam>(differentiateParam(builder, paramP));
- mapDifferentialInst(findCloneForOperand(&cloneEnv, paramP), paramD);
- newParamListD.add(paramD);
+ return ("d" + String(namehintDecoration->getName()));
}
- return newParamListD;
+ return String("");
}
- // Returns "d<var-name>" to use as a name hint for variables and parameters.
+ // Returns "dp<var-name>" to use as a name hint for parameters.
// If no primal name is available, returns a blank string.
//
- String getJVPVarName(IRInst* varP)
+ String makeDiffPairName(IRInst* origVar)
{
- if (auto namehintDecoration = varP->findDecoration<IRNameHintDecoration>())
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
{
- return ("d" + String(namehintDecoration->getName()));
+ return ("dp" + String(namehintDecoration->getName()));
}
return String("");
}
- IRInst* differentiateVar(IRBuilder* builder, IRVar* varP)
+ InstPair transcribeVar(IRBuilder* builder, IRVar* origVar)
{
- if (IRType* typeD = differentiateType(builder, varP->getDataType()->getValueType()))
+ if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType()))
{
- IRVar* varD = builder->emitVar(typeD);
- SLANG_ASSERT(varD);
+ IRVar* diffVar = builder->emitVar(diffType);
+ SLANG_ASSERT(diffVar);
- auto nameHintD = getJVPVarName(varP);
- if (nameHintD.getLength() > 0)
- builder->addNameHintDecoration(varD, nameHintD.getUnownedSlice());
+ auto diffNameHint = getJVPVarName(origVar);
+ if (diffNameHint.getLength() > 0)
+ builder->addNameHintDecoration(diffVar, diffNameHint.getUnownedSlice());
- return varD;
+ return InstPair(cloneInst(&cloneEnv, builder, origVar), diffVar);
}
- return nullptr;
+
+ return InstPair(cloneInst(&cloneEnv, builder, origVar), nullptr);
}
- IRInst* differentiateBinaryArith(IRBuilder* builder, IRInst* arith)
+ InstPair transcribeBinaryArith(IRBuilder* builder, IRInst* origArith)
{
- SLANG_ASSERT(arith->getOperandCount() == 2);
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ IRInst* primalArith = cloneInst(&cloneEnv, builder, origArith);
- auto leftP = arith->getOperand(0);
- auto rightP = arith->getOperand(1);
+ auto origLeft = origArith->getOperand(0);
+ auto origRight = origArith->getOperand(1);
- auto leftD = getDifferentialInst(leftP);
- auto rightD = getDifferentialInst(rightP);
+ auto primalLeft = findOrTranscribePrimalInst(builder, origLeft);
+ auto primalRight = findOrTranscribePrimalInst(builder, origRight);
+
+ auto diffLeft = findOrTranscribeDiffInst(builder, origLeft);
+ auto diffRight = findOrTranscribeDiffInst(builder, origRight);
- auto leftZero = builder->getFloatValue(leftP->getDataType(), 0.0);
- auto rightZero = builder->getFloatValue(rightP->getDataType(), 0.0);
+ auto leftZero = builder->getFloatValue(origLeft->getDataType(), 0.0);
+ auto rightZero = builder->getFloatValue(origRight->getDataType(), 0.0);
- if (leftD || rightD)
+ if (diffLeft || diffRight)
{
- leftD = leftD ? leftD : leftZero;
- rightD = rightD ? rightD : rightZero;
+ diffLeft = diffLeft ? diffLeft : leftZero;
+ diffRight = diffRight ? diffRight : rightZero;
- // Might have to do special-case handling for non-scalar types,
- // like float3 or float3x3
- //
- auto resultType = arith->getDataType();
- switch(arith->getOp())
+ auto resultType = origArith->getDataType();
+ switch(origArith->getOp())
{
case kIROp_Add:
- return builder->emitAdd(resultType, leftD, rightD);
+ return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight));
case kIROp_Mul:
- return builder->emitAdd(resultType,
- builder->emitMul(resultType, leftD, rightP),
- builder->emitMul(resultType, leftP, rightD));
+ return InstPair(primalArith, builder->emitAdd(resultType,
+ builder->emitMul(resultType, diffLeft, primalRight),
+ builder->emitMul(resultType, primalLeft, diffRight)));
case kIROp_Sub:
- return builder->emitSub(resultType, leftD, rightD);
+ return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight));
case kIROp_Div:
- return builder->emitDiv(resultType,
+ return InstPair(primalArith, builder->emitDiv(resultType,
builder->emitSub(
resultType,
- builder->emitMul(resultType, leftD, rightP),
- builder->emitMul(resultType, leftP, rightD)),
+ builder->emitMul(resultType, diffLeft, primalRight),
+ builder->emitMul(resultType, primalLeft, diffRight)),
builder->emitMul(
- rightP->getDataType(), rightP, rightP
- ));
+ primalRight->getDataType(), primalRight, primalRight
+ )));
default:
- getSink()->diagnose(arith->sourceLoc,
+ getSink()->diagnose(origArith->sourceLoc,
Diagnostics::unimplemented,
"this arithmetic instruction cannot be differentiated");
}
}
- return nullptr;
+ return InstPair(primalArith, nullptr);
}
- IRInst* differentiateLoad(IRBuilder* builder, IRLoad* loadP)
+ InstPair transcribeLoad(IRBuilder* builder, IRLoad* origLoad)
{
- auto ptrP = loadP->getPtr();
- if (as<IRVar>(ptrP) || as<IRParam>(ptrP))
- {
- // If the loaded parameter has a differential version,
- // emit a load instruction for the differential parameter.
- // Otherwise, emit nothing since there's nothing to load.
- //
- if (auto ptrD = getDifferentialInst(ptrP, nullptr))
- {
- IRLoad* loadD = as<IRLoad>(builder->emitLoad(ptrD));
- SLANG_ASSERT(loadD);
- return loadD;
- }
- return nullptr;
+ auto origPtr = origLoad->getPtr();
+
+ auto primalLoad = cloneInst(&cloneEnv, builder, origLoad);
+
+ if (auto diffPtr = lookupDiffInst(origPtr, nullptr))
+ {
+ IRLoad* diffLoad = as<IRLoad>(builder->emitLoad(diffPtr));
+ SLANG_ASSERT(diffLoad);
+
+ return InstPair(primalLoad, diffLoad);
}
- else
- getSink()->diagnose(loadP->sourceLoc,
- Diagnostics::unimplemented,
- "this load instruction cannot be differentiated");
- return nullptr;
+ return InstPair(primalLoad, nullptr);
}
- IRInst* differentiateStore(IRBuilder* builder, IRStore* storeP)
+ InstPair transcribeStore(IRBuilder* builder, IRStore* origStore)
{
- IRInst* storeLocation = storeP->getPtr();
- IRInst* storeVal = storeP->getVal();
- if (as<IRVar>(storeLocation) || as<IRParam>(storeLocation))
- {
- // If the stored value has a differential version,
- // emit a store instruction for the differential parameter.
- // Otherwise, emit nothing since there's nothing to load.
- //
- IRInst* storeValD = getDifferentialInst(storeVal);
- IRInst* storeLocationD = getDifferentialInst(storeLocation);
- if (storeValD && storeLocationD)
- {
- IRStore* storeD = as<IRStore>(
- builder->emitStore(storeLocationD, storeValD));
- SLANG_ASSERT(storeD);
- return storeD;
- }
- return nullptr;
+ IRInst* origStoreLocation = origStore->getPtr();
+ IRInst* origStoreVal = origStore->getVal();
+
+ auto primalStore = cloneInst(&cloneEnv, builder, origStore);
+
+ auto diffStoreLocation = lookupDiffInst(origStoreLocation, nullptr);
+ auto diffStoreVal = lookupDiffInst(origStoreVal, nullptr);
+
+ // If the stored value has a differential version,
+ // emit a store instruction for the differential parameter.
+ // Otherwise, emit nothing since there's nothing to load.
+ //
+ if (diffStoreLocation && diffStoreVal)
+ {
+ IRStore* diffStore = as<IRStore>(
+ builder->emitStore(diffStoreLocation, diffStoreVal));
+ SLANG_ASSERT(diffStore);
+
+ return InstPair(primalStore, diffStore);
}
- else
- getSink()->diagnose(storeP->sourceLoc,
- Diagnostics::unimplemented,
- "this store instruction cannot be differentiated");
- return nullptr;
+
+ return InstPair(primalStore, nullptr);
}
- IRInst* differentiateReturn(IRBuilder* builder, IRReturn* returnP)
+ InstPair transcribeReturn(IRBuilder* builder, IRReturn* origReturn)
{
- IRInst* returnVal = returnP->getVal();
- if (auto returnValD = getDifferentialInst(returnVal, nullptr))
+ IRInst* origReturnVal = origReturn->getVal();
+
+ if (auto pairType = tryGetDiffPairType(builder, origReturnVal->getDataType()))
{
- IRReturn* returnD = as<IRReturn>(builder->emitReturn(returnValD));
- SLANG_ASSERT(returnD);
- return returnD;
+ IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal);
+
+ IRInst* diffReturnVal = findOrTranscribeDiffInst(builder, origReturnVal);
+ if(!diffReturnVal)
+ diffReturnVal = getZeroOfType(builder, origReturnVal->getDataType());
+
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal);
+ IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair));
+ return InstPair(pairReturn, pairReturn);
}
else
{
// If the differential return value is not available, emit a
// void return.
- return builder->emitReturn();
+ IRInst* voidReturn = builder->emitReturn();
+ return InstPair(voidReturn, voidReturn);
}
}
@@ -314,68 +678,72 @@ struct JVPTranscriber
// instruction, we check to make sure that the nested instr is a constant
// and then return nullptr. Literals do not need to be differentiated.
//
- IRInst* differentiateConstruct(IRBuilder*, IRInst* consP)
+ InstPair transcribeConstruct(IRBuilder* builder, IRInst* origConstruct)
{
- if (as<IRConstant>(consP->getOperand(0)) && consP->getOperandCount() == 1)
- return nullptr;
+ IRInst* primalConstruct = cloneInst(&cloneEnv, builder, origConstruct);
+
+ if (as<IRConstant>(origConstruct->getOperand(0)) && origConstruct->getOperandCount() == 1)
+ return InstPair(primalConstruct, nullptr);
else
- getSink()->diagnose(consP->sourceLoc,
+ getSink()->diagnose(origConstruct->sourceLoc,
Diagnostics::unimplemented,
"this construct instruction cannot be differentiated");
- return nullptr;
+
+ return InstPair(primalConstruct, nullptr);
}
// Differentiating a call instruction here is primarily about generating
// an appropriate call list based on whichever parameters have differentials
// in the current transcription context.
- // Note(sai): Currently we don't look at modifiers (in, out, const etc..) in the function
- // type, and so only support 'plain' parameters. We need to validte this somewhere to
- // avoid weird behaviour
//
- IRInst* differentiateCall(IRBuilder* builder, IRCall* callP)
+ InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
{
- if (auto calleeP = as<IRFunc>(callP->getCallee()))
+ if (auto origCallee = as<IRFunc>(origCall->getCallee()))
{
// Build the differential callee
- IRInst* calleeD = builder->emitJVPDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(calleeP->getFullType())),
- calleeP);
+ IRInst* diffCall = builder->emitJVPDifferentiateInst(
+ differentiateFunctionType(builder, as<IRFuncType>(origCallee->getFullType())),
+ origCallee);
List<IRInst*> args;
- // Go over the parameter list and all primal arguments.
- for (UIndex ii = 0; ii < callP->getArgCount(); ii++)
+ // Go over the parameter list and create pairs for each input (if required)
+ for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
{
- args.add(callP->getArg(ii));
- }
+ auto origArg = origCall->getArg(ii);
+ auto primalArg = findOrTranscribePrimalInst(builder, origArg);
+ SLANG_ASSERT(primalArg);
- {
- IRParam* param = calleeP->getFirstParam();
- // Go over the parameter list again and arguments for types that need differentials.
- for (UIndex ii = 0; ii < callP->getArgCount(); ii++)
+ auto origType = origArg->getDataType();
+ if (auto pairType = tryGetDiffPairType(builder, origType))
{
- // Look the parameter up in the callee's signature. If it requires a derivative, proceed.
- // Otherwise, continue.
- //
- if (differentiateType(builder, param->getDataType()))
- {
- // If the corresponding argument does not have a differential, create and place a
- // 0 argument.
- //
- auto argP = callP->getArg(ii);
- if (auto argD = getDifferentialInst(argP, nullptr))
- args.add(argD);
- else
- args.add(getZeroOfType(builder, argP->getDataType()));
- }
+
+ auto diffArg = findOrTranscribeDiffInst(builder, origArg);
- param = param->getNextParam();
+ // TODO(sai): This part is flawed. Replace with a call to the
+ // 'zero()' interface method.
+ if (!diffArg)
+ diffArg = getZeroOfType(builder, origType);
+
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
+
+ args.add(diffPair);
+ }
+ else
+ {
+ // Add original/primal argument.
+ args.add(primalArg);
}
}
- return builder->emitCallInst(differentiateType(builder, callP->getFullType()),
- calleeD,
- args);
+ auto callInst = builder->emitCallInst(
+ tryGetDiffPairType(builder, origCall->getFullType()),
+ diffCall,
+ args);
+
+ return InstPair(
+ pairBuilder->emitPrimalFieldAccess(builder, callInst),
+ pairBuilder->emitDiffFieldAccess(builder, callInst));
}
else
{
@@ -384,31 +752,40 @@ struct JVPTranscriber
// differentiate such calls safely.
// TODO(sai): Should probably get checked in the front-end.
//
- getSink()->diagnose(callP->sourceLoc,
+ getSink()->diagnose(origCall->sourceLoc,
Diagnostics::internalCompilerError,
"attempting to differentiate unresolved callee");
}
- return nullptr;
+
+ return InstPair(nullptr, nullptr);
}
- IRInst* differentiateSwizzle(IRBuilder* builder, IRSwizzle* swizzleP)
+ InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
{
- if (auto baseD = getDifferentialInst(swizzleP->getBase(), nullptr))
+ IRInst* primalSwizzle = cloneInst(&cloneEnv, builder, origSwizzle);
+
+ if (auto diffBase = lookupDiffInst(origSwizzle->getBase(), nullptr))
{
List<IRInst*> swizzleIndices;
- for (UIndex ii = 0; ii < swizzleP->getElementCount(); ii++)
- swizzleIndices.add(swizzleP->getElementIndex(ii));
+ for (UIndex ii = 0; ii < origSwizzle->getElementCount(); ii++)
+ swizzleIndices.add(origSwizzle->getElementIndex(ii));
- return builder->emitSwizzle(differentiateType(builder, swizzleP->getDataType()),
- baseD,
- swizzleP->getElementCount(),
- swizzleIndices.getBuffer());
+ return InstPair(
+ primalSwizzle,
+ builder->emitSwizzle(
+ differentiateType(builder, origSwizzle->getDataType()),
+ diffBase,
+ origSwizzle->getElementCount(),
+ swizzleIndices.getBuffer()));
}
- return nullptr;
+
+ return InstPair(primalSwizzle, nullptr);
}
- IRInst* differentiateByPassthrough(IRBuilder* builder, IRInst* origInst)
+ InstPair transcribeByPassthrough(IRBuilder* builder, IRInst* origInst)
{
+ IRInst* primalInst = cloneInst(&cloneEnv, builder, origInst);
+
UCount operandCount = origInst->getOperandCount();
List<IRInst*> diffOperands;
@@ -419,20 +796,22 @@ struct JVPTranscriber
// Otherwise, abandon the differentiation attempt and assume that origInst
// cannot (or does not need to) be differentiated.
//
- if (auto diffInst = getDifferentialInst(origInst->getOperand(ii), nullptr))
+ if (auto diffInst = lookupDiffInst(origInst->getOperand(ii), nullptr))
diffOperands.add(diffInst);
else
- return nullptr;
+ return InstPair(primalInst, nullptr);
}
- return builder->emitIntrinsicInst(
- differentiateType(builder, origInst->getDataType()),
- origInst->getOp(),
- operandCount,
- diffOperands.getBuffer());
+ return InstPair(
+ primalInst,
+ builder->emitIntrinsicInst(
+ differentiateType(builder, origInst->getDataType()),
+ origInst->getOp(),
+ operandCount,
+ diffOperands.getBuffer()));
}
- IRInst* handleControlFlow(IRBuilder* builder, IRInst* origInst)
+ InstPair transcribeControlFlow(IRBuilder* builder, IRInst* origInst)
{
switch(origInst->getOp())
{
@@ -443,17 +822,41 @@ struct JVPTranscriber
if (origBranch->getOperandCount() > 1)
break;
- if (auto diffBlock = getDifferentialInst(origBranch->getTargetBlock(), nullptr))
- return builder->emitBranch(as<IRBlock>(diffBlock));
- else
- return nullptr;
+ IRInst* diffBranch = nullptr;
+
+ if (auto diffBlock = lookupDiffInst(origBranch->getTargetBlock(), nullptr))
+ diffBranch = builder->emitBranch(as<IRBlock>(diffBlock));
+
+ // For now, every block in the original fn must have a corresponding
+ // block to compute both primals and derivatives.
+ SLANG_ASSERT(diffBranch);
+
+ return InstPair(diffBranch, diffBranch);
}
getSink()->diagnose(
origInst->sourceLoc,
Diagnostics::unimplemented,
"attempting to differentiate unhandled control flow");
- return nullptr;
+
+ return InstPair(nullptr, nullptr);
+ }
+
+
+ InstPair transcribeConst(IRBuilder*, IRInst* origInst)
+ {
+ switch(origInst->getOp())
+ {
+ case kIROp_FloatLit:
+ return InstPair(origInst, nullptr);
+ }
+
+ getSink()->diagnose(
+ origInst->sourceLoc,
+ Diagnostics::unimplemented,
+ "attempting to differentiate unhandled const type");
+
+ return InstPair(nullptr, nullptr);
}
// In differential computation, the 'default' differential value is always zero.
@@ -470,6 +873,15 @@ struct JVPTranscriber
return builder->getFloatValue(type, 0.0);
case kIROp_IntType:
return builder->getIntValue(type, 0);
+ case kIROp_VectorType:
+ {
+ IRInst* args[] = {getZeroOfType(builder, as<IRVectorType>(type)->getElementType())};
+ return builder->emitIntrinsicInst(
+ type,
+ kIROp_constructVectorFromScalar,
+ 1,
+ args);
+ }
default:
getSink()->diagnose(type->sourceLoc,
Diagnostics::internalCompilerError,
@@ -478,126 +890,85 @@ struct JVPTranscriber
}
}
- // Logic for whether a primal instruction needs to be replicated
- // in the differential function. We detect and avoid replicating
- // 'side-effect' instructions.
- //
- bool isPurelyFunctional(IRBuilder*, IRInst* instP)
+ IRInst* transcribe(IRBuilder* builder, IRInst* origInst)
{
- if (as<IRTerminatorInst>(instP))
- return false;
- else if (auto paramP = as<IRParam>(instP))
- {
- // Out-type parameters are discarded from the parameter list,
- // since pure JVP functions to not write to primal outputs.
- //
- if (as<IROutType>(paramP->getDataType()))
- return false;
- }
- else if (auto storeP = as<IRStore>(instP))
- {
- IRInst* storeLocation = storeP->getPtr();
+ InstPair pair = transcribeInst(builder, origInst);
- // Writing to a parameter is a side-effect that should be avoided.
- if(as<IRParam>(storeLocation))
- return false;
-
- // If attempting to store to a location without a clone,
- // then this instruction likely has side-effects external to the
- // current function.
- //
- if(!lookUp(&cloneEnv, storeLocation))
- return false;
- }
-
- return true;
- }
-
- IRInst* transcribe(IRBuilder* builder, IRInst* oldInstP)
- {
-
- // Clone the old instruction into the new differential function.
- //
- IRInst* instP = cloneInst(&cloneEnv, builder, oldInstP);
-
- SLANG_ASSERT(instP);
-
- IRInst* instD = differentiateInst(builder, instP);
-
- // In case it's not safe to clone the old instruction,
- // remove it from the graph.
- // For instance, instructions that handle control flow
- // (return statements) shouldn't be replicated.
- //
- if (isPurelyFunctional(builder, oldInstP))
- mapDifferentialInst(instP, instD);
- else
+ if (auto primalInst = pair.primal)
{
- // This inst should never have been used.
- SLANG_ASSERT(instP->firstUse == nullptr);
+ mapPrimalInst(origInst, pair.primal);
- instP->removeAndDeallocate();
- mapDifferentialInst(oldInstP, instD);
+ mapDifferentialInst(origInst, pair.differential);
+ return pair.differential;
}
- return instD;
+ getSink()->diagnose(origInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "failed to transcibe instruction");
+ return nullptr;
}
- IRInst* differentiateInst(IRBuilder* builder, IRInst* instP)
+ InstPair transcribeInst(IRBuilder* builder, IRInst* origInst)
{
// Handle common operations
- switch (instP->getOp())
+ switch (origInst->getOp())
{
+ case kIROp_Param:
+ return transcribeParam(builder, as<IRParam>(origInst));
+
case kIROp_Var:
- return differentiateVar(builder, as<IRVar>(instP));
+ return transcribeVar(builder, as<IRVar>(origInst));
case kIROp_Load:
- return differentiateLoad(builder, as<IRLoad>(instP));
+ return transcribeLoad(builder, as<IRLoad>(origInst));
case kIROp_Store:
- return differentiateStore(builder, as<IRStore>(instP));
+ return transcribeStore(builder, as<IRStore>(origInst));
case kIROp_Return:
- return differentiateReturn(builder, as<IRReturn>(instP));
+ return transcribeReturn(builder, as<IRReturn>(origInst));
case kIROp_Add:
case kIROp_Mul:
case kIROp_Sub:
case kIROp_Div:
- return differentiateBinaryArith(builder, instP);
+ return transcribeBinaryArith(builder, origInst);
case kIROp_Construct:
- return differentiateConstruct(builder, instP);
+ return transcribeConstruct(builder, origInst);
case kIROp_Call:
- return differentiateCall(builder, as<IRCall>(instP));
+ return transcribeCall(builder, as<IRCall>(origInst));
case kIROp_swizzle:
- return differentiateSwizzle(builder, as<IRSwizzle>(instP));
+ return transcribeSwizzle(builder, as<IRSwizzle>(origInst));
case kIROp_constructVectorFromScalar:
- return differentiateByPassthrough(builder, instP);
+ return transcribeByPassthrough(builder, origInst);
case kIROp_unconditionalBranch:
case kIROp_conditionalBranch:
- return handleControlFlow(builder, instP);
+ return transcribeControlFlow(builder, origInst);
+
+ case kIROp_FloatLit:
+ return transcribeConst(builder, origInst);
}
// If none of the cases have been hit, check if the instruction is a
// type.
// For now we don't have logic to differentiate types that appear in blocks.
- // So, we ignore them.
+ // So, we clone and avoid differentiating them.
//
- if (as<IRType>(instP))
- return nullptr;
-
+ if (auto origType = as<IRType>(origInst))
+ return InstPair(cloneInst(&cloneEnv, builder, origType), nullptr);
// If we reach this statement, the instruction type is likely unhandled.
- getSink()->diagnose(instP->sourceLoc,
+ getSink()->diagnose(origInst->sourceLoc,
Diagnostics::unimplemented,
"this instruction cannot be differentiated");
- return nullptr;
+
+ return InstPair(nullptr, nullptr);
}
};
@@ -659,8 +1030,19 @@ struct JVPDerivativeContext
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
- // processMarkedGlobalFunctions(builder);
- return processReferencedFunctions(builder);
+ // Process all JVPDifferentiate instructions (kIROp_JVPDifferentiate), by
+ // generating derivative code for the referenced function.
+ //
+ bool modified = processReferencedFunctions(builder);
+
+ // Replaces IRDifferentialPairType with an auto-generated struct,
+ // IRDifferentialPairGetDifferential with 'differential' field access,
+ // IRDifferentialPairGetPrimal with 'primal' field access, and
+ // IRMakeDifferentialPair with an IRMakeStruct.
+ //
+ modified |= processPairTypes(builder, module->getModuleInst(), (&diffConformanceContextStorage));
+
+ return modified;
}
IRInst* lookupJVPReference(IRInst* primalFunction)
@@ -683,7 +1065,7 @@ struct JVPDerivativeContext
// Keep processing items until the queue is complete.
while (IRInst* workItem = workQueue->pop())
- {
+ {
for(auto child = workItem->getFirstChild(); child; child = child->getNextInst())
{
// Either the child instruction has more children (func/block etc..)
@@ -749,6 +1131,132 @@ struct JVPDerivativeContext
return true;
}
+ IRInst* lowerPairType(IRBuilder* builder, IRType* type, DifferentiableTypeConformanceContext* diffContext)
+ {
+ if (diffContext->isInterfaceAvailable)
+ {
+ if (auto pairType = as<IRDifferentialPairType>(type))
+ {
+ builder->setInsertBefore(pairType);
+
+ auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
+ builder,
+ pairType->getValueType());
+
+ pairType->replaceUsesWith(diffPairStructType);
+ pairType->removeAndDeallocate();
+
+ return diffPairStructType;
+ }
+ else if (auto loweredStructType = as<IRStructType>(type))
+ {
+ // Already lowered to struct.
+ return loweredStructType;
+ }
+ }
+ return nullptr;
+ }
+
+ IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
+ {
+
+ if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
+ {
+ auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext);
+
+ builder->setInsertBefore(makePairInst);
+
+ List<IRInst*> operands;
+ operands.add(makePairInst->getPrimalValue());
+ operands.add(makePairInst->getDifferentialValue());
+
+ auto makeStructInst = builder->emitMakeStruct(as<IRStructType>(diffPairStructType), operands);
+ makePairInst->replaceUsesWith(makeStructInst);
+ makePairInst->removeAndDeallocate();
+
+ return makeStructInst;
+ }
+
+ return nullptr;
+ }
+
+ IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
+ {
+
+ if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
+ {
+ lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext);
+
+ builder->setInsertBefore(getDiffInst);
+
+ auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ getDiffInst->replaceUsesWith(diffFieldExtract);
+ getDiffInst->removeAndDeallocate();
+
+ return diffFieldExtract;
+ }
+ else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
+ {
+ lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext);
+
+ builder->setInsertBefore(getPrimalInst);
+
+ auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ getPrimalInst->replaceUsesWith(primalFieldExtract);
+ getPrimalInst->removeAndDeallocate();
+
+ return primalFieldExtract;
+ }
+
+ return nullptr;
+ }
+
+ bool processPairTypes(IRBuilder* builder, IRInst* instWithChildren, DifferentiableTypeConformanceContext* diffContext)
+ {
+ bool modified = false;
+
+ // Create a new sub-context to scan witness tables inside workItem
+ // (mainly relevant if instWithChildren is a generic scope)
+ //
+ auto subContext = DifferentiableTypeConformanceContext(diffContext, instWithChildren);
+ (&pairBuilderStorage)->diffConformanceContext = (&subContext);
+
+ for (auto child = instWithChildren->getFirstChild(); child; )
+ {
+ // Make sure the builder is at the right level.
+ builder->setInsertInto(instWithChildren);
+
+ auto nextChild = child->getNextInst();
+
+ switch (child->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ lowerPairType(builder, as<IRType>(child), &subContext);
+ break;
+
+ case kIROp_DifferentialPairGetDifferential:
+ case kIROp_DifferentialPairGetPrimal:
+ lowerPairAccess(builder, child, &subContext);
+ break;
+
+ case kIROp_MakeDifferentialPair:
+ lowerMakePair(builder, child, &subContext);
+ break;
+
+ default:
+ if (child->getFirstChild())
+ modified = processPairTypes(builder, child, (&subContext)) | modified;
+ }
+
+ child = nextChild;
+ }
+
+ // Reset the context back to the parent.
+ (&pairBuilderStorage)->diffConformanceContext = diffContext;
+
+ return modified;
+ }
+
// Checks decorators to see if the function should
// be differentiated (kIROp_JVPDerivativeMarkerDecoration)
//
@@ -823,12 +1331,13 @@ struct JVPDerivativeContext
{
auto jvpBlock = builder->emitBlock();
transcriberStorage.mapDifferentialInst(block, jvpBlock);
+ transcriberStorage.mapPrimalInst(block, jvpBlock);
}
// Go back over the blocks, and process the children of each block.
for (auto block = primalFn->getFirstBlock(); block; block = block->getNextBlock())
{
- auto jvpBlock = as<IRBlock>(transcriberStorage.getDifferentialInst(block, block));
+ auto jvpBlock = as<IRBlock>(transcriberStorage.lookupDiffInst(block, block));
SLANG_ASSERT(jvpBlock);
emitJVPBlock(builder, block, jvpBlock);
}
@@ -858,7 +1367,7 @@ struct JVPDerivativeContext
}
IRBlock* emitJVPBlock(IRBuilder* builder,
- IRBlock* primalBlock,
+ IRBlock* origBlock,
IRBlock* jvpBlock = nullptr)
{
JVPTranscriber* transcriber = &(transcriberStorage);
@@ -869,16 +1378,17 @@ struct JVPDerivativeContext
else
builder->setInsertInto(jvpBlock);
- // First transcribe the parameter list. This is done separately because we
- // want all the derivative parameters emitted after the primal parameters
- // rather than interleaved with one another.
- //
- transcriber->transcribeParams(builder, primalBlock->getParams());
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ {
+ transcriber->transcribe(builder, param);
+ }
- // Run through every instruction and use the transcriber to generate the appropriate
+ // Then, run through every instruction and use the transcriber to generate the appropriate
// derivative code.
//
- for(auto child = primalBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
{
transcriber->transcribe(builder, child);
}
@@ -886,9 +1396,14 @@ struct JVPDerivativeContext
return jvpBlock;
}
- JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : module(module), sink(sink)
+ JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
+ module(module), sink(sink),
+ diffConformanceContextStorage(module->getModuleInst()),
+ pairBuilderStorage(&diffConformanceContextStorage)
{
transcriberStorage.sink = sink;
+ transcriberStorage.diffConformanceContext = &(diffConformanceContextStorage);
+ transcriberStorage.pairBuilder = &(pairBuilderStorage);
}
protected:
@@ -913,7 +1428,14 @@ struct JVPDerivativeContext
// Work queue to hold a stream of instructions that need
// to be checked for references to derivative functions.
- IRWorkQueue workQueueStorage;
+ IRWorkQueue workQueueStorage;
+
+ // Context to find and manage the witness tables for types
+ // implementing `IDifferentiable`
+ DifferentiableTypeConformanceContext diffConformanceContextStorage;
+
+ // Builder for dealing with differential pair types.
+ DifferentialPairTypeBuilder pairBuilderStorage;
};
@@ -923,15 +1445,15 @@ bool processJVPDerivativeMarkers(
IRModule* module,
DiagnosticSink* sink,
IRJVPDerivativePassOptions const&)
-{
- JVPDerivativeContext context(module, sink);
-
+{
// Simplify module to remove dead code.
IRDeadCodeEliminationOptions options;
options.keepExportsAlive = true;
options.keepLayoutsAlive = true;
eliminateDeadCode(module, options);
+ JVPDerivativeContext context(module, sink);
+
return context.processModule();
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index afb132af7..aeb6d4ea1 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -58,6 +58,8 @@ INST(Nop, nop, 0, 0)
INST(AttributedType, Attributed, 0, 0)
INST(ResultType, Result, 2, 0)
+ INST(DifferentialPairType, DiffPair, 1, 0)
+
/* BindExistentialsTypeBase */
// A `BindExistentials<B, T0,w0, T1,w1, ...>` represents
@@ -265,6 +267,10 @@ INST(undefined, undefined, 0, 0)
//
INST(DefaultConstruct, defaultConstruct, 0, 0)
+INST(MakeDifferentialPair, MakeDiffPair, 2, 0)
+INST(DifferentialPairGetDifferential, GetDifferential, 1, 0)
+INST(DifferentialPairGetPrimal, GetPrimal, 1, 0)
+
INST(Specialize, specialize, 2, 0)
INST(lookup_interface_method, lookup_interface_method, 2, 0)
INST(GetSequentialID, GetSequentialID, 1, 0)
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index f5d5a86ac..2e2dbed5a 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -1841,6 +1841,27 @@ struct IRGetTupleElement : IRInst
IRInst* getElementIndex() { return getOperand(1); }
};
+// An Instruction that creates a differential pair value from a
+// primal and differential.
+struct IRMakeDifferentialPair : IRInst
+{
+ IR_LEAF_ISA(MakeDifferentialPair)
+ IRInst* getPrimalValue() { return getOperand(0); }
+ IRInst* getDifferentialValue() { return getOperand(1); }
+};
+
+struct IRDifferentialPairGetDifferential : IRInst
+{
+ IR_LEAF_ISA(DifferentialPairGetDifferential)
+ IRInst* getBase() { return getOperand(0); }
+};
+
+struct IRDifferentialPairGetPrimal : IRInst
+{
+ IR_LEAF_ISA(DifferentialPairGetPrimal)
+ IRInst* getBase() { return getOperand(0); }
+};
+
// Constructs an `Result<T,E>` value from an error code.
struct IRMakeResultError : IRInst
{
@@ -2278,6 +2299,10 @@ public:
IRInst* rowCount,
IRInst* columnCount);
+ IRDifferentialPairType* getDifferentialPairType(
+ IRType* valueType,
+ IRWitnessTable* witnessTable);
+
IRFuncType* getFuncType(
UInt paramCount,
IRType* const* paramTypes,
@@ -2384,6 +2409,8 @@ public:
IRInst* emitJVPDifferentiateInst(IRType* type, IRInst* baseFn);
+ IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
+
IRInst* emitSpecializeInst(
IRType* type,
IRInst* genericVal,
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 86586c2e8..c66f0d555 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -2741,6 +2741,17 @@ namespace Slang
operands);
}
+ IRDifferentialPairType* IRBuilder::getDifferentialPairType(
+ IRType* valueType,
+ IRWitnessTable* witnessTable)
+ {
+ IRInst* operands[] = { valueType, witnessTable };
+ return (IRDifferentialPairType*)getType(
+ kIROp_DifferentialPairType,
+ sizeof(operands) / sizeof(operands[0]),
+ operands);
+ }
+
IRFuncType* IRBuilder::getFuncType(
UInt paramCount,
IRType* const* paramTypes,
@@ -3043,6 +3054,15 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
+ {
+ IRInst* args[] = {primal, differential};
+ auto inst = createInstWithTrailingArgs<IRMakeDifferentialPair>(
+ this, kIROp_MakeDifferentialPair, type, 2, args);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitSpecializeInst(
IRType* type,
IRInst* genericVal,
diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h
index 3faa5884c..47a724def 100644
--- a/source/slang/slang-ir.h
+++ b/source/slang/slang-ir.h
@@ -1255,6 +1255,13 @@ SIMPLE_IR_TYPE(TypeKind, Kind);
//
SIMPLE_IR_TYPE(GenericKind, Kind)
+struct IRDifferentialPairType : IRType
+{
+ IRType* getValueType() { return (IRType*)getOperand(0); }
+
+ IR_LEAF_ISA(DifferentialPairType)
+};
+
struct IRVectorType : IRType
{
IRType* getElementType() { return (IRType*)getOperand(0); }
diff --git a/source/slang/slang-stdlib.cpp b/source/slang/slang-stdlib.cpp
index 2bbf3c1c0..628a075e3 100644
--- a/source/slang/slang-stdlib.cpp
+++ b/source/slang/slang-stdlib.cpp
@@ -257,4 +257,21 @@ namespace Slang
#endif
return hlslLibraryCode;
}
+
+ String Session::getAutodiffLibraryCode()
+ {
+#if !defined(SLANG_DISABLE_STDLIB_SOURCE)
+ if (autodiffLibraryCode.getLength() > 0)
+ return autodiffLibraryCode;
+
+ const String path = getStdlibPath();
+
+ StringBuilder sb;
+
+ #include "diff.meta.slang.h"
+
+ autodiffLibraryCode = sb.ProduceString();
+#endif
+ return autodiffLibraryCode;
+ }
}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index b3762e471..dd53c98bf 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -178,8 +178,11 @@ void Session::init()
coreLanguageScope = builtinAstBuilder->create<Scope>();
coreLanguageScope->nextSibling = baseLanguageScope;
+ autodiffLanguageScope = builtinAstBuilder->create<Scope>();
+ autodiffLanguageScope->nextSibling = coreLanguageScope;
+
hlslLanguageScope = builtinAstBuilder->create<Scope>();
- hlslLanguageScope->nextSibling = coreLanguageScope;
+ hlslLanguageScope->nextSibling = autodiffLanguageScope;
slangLanguageScope = builtinAstBuilder->create<Scope>();
slangLanguageScope->nextSibling = hlslLanguageScope;
@@ -290,6 +293,7 @@ SlangResult Session::compileStdLib(slang::CompileStdLibFlags compileFlags)
// TODO(JS): Could make this return a SlangResult as opposed to exception
addBuiltinSource(coreLanguageScope, "core", getCoreLibraryCode());
addBuiltinSource(hlslLanguageScope, "hlsl", getHLSLLibraryCode());
+ addBuiltinSource(autodiffLanguageScope, "diff", getAutodiffLibraryCode());
if (compileFlags & slang::CompileStdLibFlag::WriteDocumentation)
{
@@ -348,6 +352,7 @@ SlangResult Session::loadStdLib(const void* stdLib, size_t stdLibSizeInBytes)
// Let's try loading serialized modules and adding them
SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, coreLanguageScope, "core"));
SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, hlslLanguageScope, "hlsl"));
+ SLANG_RETURN_ON_FAIL(_readBuiltinModule(fileSystem, autodiffLanguageScope, "diff"));
return SLANG_OK;
}