summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/diff.meta.slang3
-rw-r--r--source/slang/slang-ast-expr.h8
-rw-r--r--source/slang/slang-ast-modifier.h6
-rw-r--r--source/slang/slang-check-expr.cpp59
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-overload.cpp55
-rw-r--r--source/slang/slang-emit.cpp6
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp824
-rw-r--r--source/slang/slang-ir-diff-jvp.h2
-rw-r--r--source/slang/slang-ir-inst-defs.h9
-rw-r--r--source/slang/slang-ir-insts.h37
-rw-r--r--source/slang/slang-ir.cpp22
-rw-r--r--source/slang/slang-lower-to-ir.cpp19
-rw-r--r--source/slang/slang-parser.cpp23
-rw-r--r--source/slang/slang.natvis1
-rw-r--r--tests/autodiff/backward-diff-smoke.slang25
-rw-r--r--tests/autodiff/backward-diff-smoke.slang.expected.txt6
17 files changed, 1040 insertions, 67 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index 1f6064983..6f1008277 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -9,6 +9,9 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute;
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute;
+
__attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index ef6a05c71..86b72e05a 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -451,6 +451,14 @@ class ForwardDifferentiateExpr: public HigherOrderInvokeExpr
SLANG_AST_CLASS(ForwardDifferentiateExpr)
};
+ /// An expression of the form `__bwd_diff(fn)` to access the
+ /// forward-mode derivative version of the function `fn`
+ ///
+class BackwardDifferentiateExpr: public HigherOrderInvokeExpr
+{
+ SLANG_AST_CLASS(BackwardDifferentiateExpr)
+};
+
/// A type expression of the form `__TaggedUnion(A, ...)`.
///
/// An expression of this form will resolve to a `TaggedUnionType`
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index d6a961328..8419facce 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1045,6 +1045,12 @@ class ForwardDerivativeOfAttribute : public Attribute
Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
};
+ /// The `[BackwardDifferentiable]` attribute indicates that a function can be backward-differentiated.
+class BackwardDifferentiableAttribute : public DifferentiableAttribute
+{
+ SLANG_AST_CLASS(BackwardDifferentiableAttribute)
+};
+
/// Indicates that the modified declaration is one of the "magic" declarations
/// that NVAPI uses to communicate extended operations. When NVAPI is being included
/// via the prelude for downstream compilation, declarations with this modifier
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 30db9ecfa..f568dd8df 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -969,6 +969,14 @@ namespace Slang
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
}
+ // Differentiable type checking.
+ // TODO: This can be super slow.
+ if (this->m_parentFunc &&
+ this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>())
+ {
+ maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
+ }
+
return checkedTerm;
}
@@ -2027,6 +2035,45 @@ namespace Slang
return jvpType;
}
+ Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType)
+ {
+ // Resolve backward diff type here.
+ // Note that this type checking needs to be in sync with
+ // the auto-generation logic in slang-ir-jvp-diff.cpp
+
+ FuncType* type = m_astBuilder->create<FuncType>();
+
+ // The backward diff return type is void
+ //
+ type->resultType = m_astBuilder->getVoidType();
+
+ // No support for differentiating function that throw errors, for now.
+ SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType()));
+ type->errorType = originalType->errorType;
+
+ for (UInt i = 0; i < originalType->getParamCount(); i++)
+ {
+ if (auto derivType = _toDifferentialParamType(originalType->getParamType(i)))
+ {
+ // Using inout type on all the derivative parameters
+ if (auto outType = as<OutType>(derivType))
+ {
+ derivType = outType->getValueType();
+ }
+ else if (!as<PtrTypeBase>(derivType))
+ {
+ derivType = m_astBuilder->getInOutType(derivType);
+ }
+ type->paramTypes.add(derivType);
+ }
+ }
+
+ // Last parameter is the initial derivative of the original return type
+ type->paramTypes.add(originalType->resultType);
+
+ return type;
+ }
+
Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr)
{
// Check/Resolve inner function declaration.
@@ -2039,6 +2086,18 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
+ {
+ // Check/Resolve inner function declaration.
+ expr->baseFunction = CheckTerm(expr->baseFunction);
+
+ // For now we only support using higher order expr as callee in an invoke expr.
+ // The actual type of the higher order function will be derived during resolve invoke.
+ expr->type = m_astBuilder->getBottomType();
+
+ return expr;
+ }
+
Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
expr->arrayExpr = CheckTerm(expr->arrayExpr);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 70b120518..e7681212f 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -714,6 +714,7 @@ namespace Slang
// Convert a function's original type to it's JVP type.
Type* processJVPFuncType(FuncType* originalType);
+ Type* processBackwardDiffFuncType(FuncType* originalType);
/// Registers a type as conforming to IDifferentiable, along with a witness
/// describing the relationship.
@@ -1908,6 +1909,7 @@ namespace Slang
Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr);
Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr);
+ Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr);
Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr);
diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp
index 38754d170..fe9de9433 100644
--- a/source/slang/slang-check-overload.cpp
+++ b/source/slang/slang-check-overload.cpp
@@ -1548,35 +1548,51 @@ namespace Slang
// Lookup the higher order function and process types accordingly. In the future,
// if there are enough varieties, we can have dispatch logic instead of an
// if-else ladder.
- if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr))
+ if (auto expr = as<HigherOrderInvokeExpr>(funcExpr))
{
- if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type))
+ if (auto origFuncType = as<FuncType>(expr->baseFunction->type))
{
- // Case: __fwd_diff(name-resolved-to-decl-ref)
- auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>();
+ auto baseFuncDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<CallableDecl>();
SLANG_ASSERT(baseFuncDeclRef);
OverloadCandidate candidate;
candidate.flavor = OverloadCandidate::Flavor::Expr;
- candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType));
+ if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
+ {
+ // Case: __fwd_diff(name-resolved-to-decl-ref)
+ candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType));
+ }
+ else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
+ {
+ // Case: __bwd_diff(name-resolved-to-decl-ref)
+ candidate.funcType = as<FuncType>(processBackwardDiffFuncType(origFuncType));
+ }
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(baseFuncDeclRef);
AddOverloadCandidate(context, candidate);
}
- else if (auto origOverloadedType = as<OverloadGroupType>(jvpExpr->baseFunction->type))
+ else if (auto origOverloadedType = as<OverloadGroupType>(expr->baseFunction->type))
{
- // Case: __fwd_diff(name-resolved-to-multiple-decl-ref)
- if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction))
+ if (auto overloadExpr = as<OverloadedExpr>(expr->baseFunction))
{
for (auto item : overloadExpr->lookupResult2.items)
{
auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc));
if (!funcType)
continue;
- funcType = as<FuncType>(processJVPFuncType(funcType));
+ if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
+ {
+ // Case: __fwd_diff(name-resolved-to-decl-ref)
+ funcType = as<FuncType>(processJVPFuncType(funcType));
+ }
+ else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
+ {
+ // Case: __bwd_diff(name-resolved-to-decl-ref)
+ funcType = as<FuncType>(processBackwardDiffFuncType(funcType));
+ }
if (!funcType)
continue;
OverloadCandidate candidate;
@@ -1597,9 +1613,8 @@ namespace Slang
funcExpr->type);
}
}
- else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>())
+ else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<GenericDecl>())
{
- // Case: __fwd_diff(name-resolved-to-generic-decl)
// Get inner function
DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>(
@@ -1610,7 +1625,9 @@ namespace Slang
auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>());
// Process func type to generate JVP func type.
- auto jvpFuncType = as<FuncType>(processJVPFuncType(funcType));
+ auto jvpFuncType = as<ForwardDifferentiateExpr>(expr) ?
+ as<FuncType>(processJVPFuncType(funcType)) :
+ as<FuncType>(processBackwardDiffFuncType(funcType));
// Extract parameter list from processed type.
List<Type*> paramTypes;
@@ -1634,8 +1651,18 @@ namespace Slang
// in order to process the specialized version of the original func type.
// This could potentially be a declRef.substitute(jvpFuncType)
//
- candidate.funcType = as<FuncType>(processJVPFuncType(
- getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+ if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr))
+ {
+ // Case: __fwd_diff(name-resolved-to-generic-decl)
+ candidate.funcType = as<FuncType>(processJVPFuncType(
+ getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+ }
+ else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr))
+ {
+ // Case: __bwd_diff(name-resolved-to-generic-decl)
+ candidate.funcType = as<FuncType>(processBackwardDiffFuncType(
+ getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>())));
+ }
candidate.resultType = candidate.funcType->getResultType();
candidate.item = LookupResultItem(innerRef);
diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp
index de9d23e97..2478eccc2 100644
--- a/source/slang/slang-emit.cpp
+++ b/source/slang/slang-emit.cpp
@@ -376,11 +376,7 @@ Result linkAndOptimizeIR(
dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF");
// Process higher-order calles to auto-diff passes.
- // 1. Generate JVP code wherever necessary. (Linearization or "forward-mode" pass)
- processForwardDifferentiableFuncs(irModule, sink);
-
- // 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass)
- // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet.
+ processDifferentiableFuncs(irModule, sink);
stripAutoDiffDecorations(irModule);
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 4c7a132d0..7f5979a87 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -2017,6 +2017,664 @@ struct JVPTranscriber
}
};
+
+struct BackwardDiffTranscriber
+{
+
+ // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent
+ // their differential values.
+ Dictionary<IRInst*, IRInst*> orginalToTranscribed;
+
+ // Set of insts currently being transcribed. Used to avoid infinite loops.
+ HashSet<IRInst*> instsInProgress;
+
+ // Cloning environment to hold mapping from old to new copies for the primal
+ // instructions.
+ IRCloneEnv cloneEnv;
+
+ // Diagnostic sink for error messages.
+ DiagnosticSink* sink;
+
+ // Type conformance information.
+ AutoDiffSharedContext* autoDiffSharedContext;
+
+ // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct
+ DifferentialPairTypeBuilder* pairBuilder;
+
+ DifferentiableTypeConformanceContext differentiableTypeConformanceContext;
+
+ List<InstPair> followUpFunctionsToTranscribe;
+
+ // Map that stores the upper gradient given an IRInst*
+ Dictionary<IRInst*, List<IRInst*>> upperGradients;
+ Dictionary<IRInst*, IRInst*> primalToDiffPair;
+
+ SharedIRBuilder* sharedBuilder;
+ // Witness table that `DifferentialBottom:IDifferential`.
+ IRWitnessTable* differentialBottomWitness = nullptr;
+ Dictionary<InstPair, IRInst*> differentialPairTypes;
+
+ BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink)
+ : autoDiffSharedContext(shared)
+ , sink(inSink)
+ , differentiableTypeConformanceContext(shared)
+ , sharedBuilder(inSharedBuilder)
+ {}
+
+ DiagnosticSink* getSink()
+ {
+ SLANG_ASSERT(sink);
+ return sink;
+ }
+
+ IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType)
+ {
+ List<IRType*> newParameterTypes;
+ IRType* diffReturnType;
+
+ for (UIndex i = 0; i < funcType->getParamCount(); i++)
+ {
+ auto origType = funcType->getParamType(i);
+ if (auto diffPairType = tryGetDiffPairType(builder, origType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ newParameterTypes.add(inoutDiffPairType);
+ }
+ else
+ newParameterTypes.add(origType);
+ }
+
+ newParameterTypes.add(funcType->getResultType());
+
+ diffReturnType = builder->getVoidType();
+
+ return builder->getFuncType(newParameterTypes, diffReturnType);
+ }
+
+ IRWitnessTable* getDifferentialBottomWitness()
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(sharedBuilder->getModule()->getModuleInst());
+ auto result =
+ as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+ builder.getDifferentialBottomType()));
+ SLANG_ASSERT(result);
+ return result;
+ }
+
+ // Get or construct `:IDifferentiable` conformance for a DifferentiablePair.
+ IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(inDiffPairType->parent);
+ auto diffPairType = as<IRDifferentialPairType>(inDiffPairType);
+ SLANG_ASSERT(diffPairType);
+ auto result =
+ as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType(
+ builder.getDifferentialBottomType()));
+ if (result)
+ return result;
+
+ auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType);
+ auto diffType = differentiateType(&builder, diffPairType->getValueType());
+ auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness());
+ builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType);
+ // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`.
+
+ differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table;
+ return table;
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* getOrCreateDiffPairType(IRInst* primalType)
+ {
+ IRBuilder builder(sharedBuilder);
+ builder.setInsertInto(primalType->parent);
+ auto witness = as<IRWitnessTable>(
+ differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType));
+ if (!witness)
+ witness = getDifferentialBottomWitness();
+ return builder.getDifferentialPairType(
+ (IRType*)primalType,
+ witness);
+ }
+
+ IRType* differentiateType(IRBuilder* builder, IRType* origType)
+ {
+ IRInst* diffType = nullptr;
+ if (!orginalToTranscribed.TryGetValue(origType, diffType))
+ {
+ diffType = _differentiateTypeImpl(builder, origType);
+ orginalToTranscribed[origType] = diffType;
+ }
+ return (IRType*)diffType;
+ }
+
+ IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType)
+ {
+ if (auto ptrType = as<IRPtrTypeBase>(origType))
+ return builder->getPtrType(
+ origType->getOp(),
+ differentiateType(builder, ptrType->getValueType()));
+
+ // If there is an explicit primal version of this type in the local scope, load that
+ // otherwise use the original type.
+ //
+ IRInst* primalType = origType;
+
+ // Special case certain compound types (PtrType, FuncType, etc..)
+ // otherwise try to lookup a differential definition for the given type.
+ // If one does not exist, then we assume it's not differentiable.
+ //
+ switch (primalType->getOp())
+ {
+ case kIROp_Param:
+ if (as<IRTypeType>(primalType->getDataType()))
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(
+ builder,
+ (IRType*)primalType));
+ else if (as<IRWitnessTableType>(primalType->getDataType()))
+ return (IRType*)primalType;
+
+ case kIROp_ArrayType:
+ {
+ auto primalArrayType = as<IRArrayType>(primalType);
+ if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType()))
+ return builder->getArrayType(
+ diffElementType,
+ primalArrayType->getElementCount());
+ else
+ return nullptr;
+ }
+
+ case kIROp_DifferentialPairType:
+ {
+ auto primalPairType = as<IRDifferentialPairType>(primalType);
+ return getOrCreateDiffPairType(
+ pairBuilder->getDiffTypeFromPairType(builder, primalPairType),
+ pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType));
+ }
+
+ case kIROp_FuncType:
+ return differentiateFunctionType(builder, as<IRFuncType>(primalType));
+
+ case kIROp_OutType:
+ if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType()))
+ return builder->getOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_InOutType:
+ if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType()))
+ return builder->getInOutType(diffValueType);
+ else
+ return nullptr;
+
+ case kIROp_TupleType:
+ {
+ auto tupleType = as<IRTupleType>(primalType);
+ List<IRType*> diffTypeList;
+ // TODO: what if we have type parameters here?
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++)
+ diffTypeList.add(
+ differentiateType(builder, (IRType*)tupleType->getOperand(ii)));
+
+ return builder->getTupleType(diffTypeList);
+ }
+
+ default:
+ return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType));
+ }
+ }
+
+ IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType)
+ {
+ // If this is a PtrType (out, inout, etc..), then create diff pair from
+ // value type and re-apply the appropropriate PtrType wrapper.
+ //
+ if (auto origPtrType = as<IRPtrTypeBase>(primalType))
+ {
+ if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType()))
+ return builder->getPtrType(primalType->getOp(), diffPairValueType);
+ else
+ return nullptr;
+ }
+ auto diffType = differentiateType(builder, primalType);
+ if (diffType)
+ return (IRType*)getOrCreateDiffPairType(primalType);
+ return nullptr;
+ }
+
+ InstPair transcribeParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+ // Do not differentiate generic type (and witness table) parameters
+ if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType))
+ {
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffPairParam);
+
+ if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
+ }
+
+
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+
+ // Returns "dp<var-name>" to use as a name hint for parameters.
+ // If no primal name is available, returns a blank string.
+ //
+ String makeDiffPairName(IRInst* origVar)
+ {
+ if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>())
+ {
+ return ("dp" + String(namehintDecoration->getName()));
+ }
+
+ return String("");
+ }
+
+
+ // In differential computation, the 'default' differential value is always zero.
+ // This is a consequence of differential computing being inherently linear. As a
+ // result, it's useful to have a method to generate zero literals of any (arithmetic) type.
+ // The current implementation requires that types are defined linearly.
+ //
+ IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType)
+ {
+ if (auto diffType = differentiateType(builder, primalType))
+ {
+ switch (diffType->getOp())
+ {
+ case kIROp_DifferentialPairType:
+ return builder->emitMakeDifferentialPair(
+ diffType,
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()),
+ getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()));
+ }
+ // Since primalType has a corresponding differential type, we can lookup the
+ // definition for zero().
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType);
+ SLANG_ASSERT(zeroMethod);
+
+ auto emptyArgList = List<IRInst*>();
+ return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList);
+ }
+ else
+ {
+ if (isScalarIntegerType(primalType))
+ {
+ return builder->getIntValue(primalType, 0);
+ }
+
+ getSink()->diagnose(primalType->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "could not generate zero value for given type");
+ return nullptr;
+ }
+ }
+
+ InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock)
+ {
+ IRBuilder subBuilder(builder->getSharedBuilder());
+ subBuilder.setInsertLoc(builder->getInsertLoc());
+
+ IRBlock* diffBlock = subBuilder.emitBlock();
+
+ subBuilder.setInsertInto(diffBlock);
+
+ // First transcribe every parameter in the block.
+ for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam())
+ this->copyParam(&subBuilder, param);
+
+ // The extra param for input gradient
+ auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType());
+
+ // Then, run through every instruction and use the transcriber to generate the appropriate
+ // derivative code.
+ //
+ for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst())
+ this->copyInst(&subBuilder, child);
+
+ auto lastInst = diffBlock->getLastOrdinaryInst();
+ List<IRInst*> grads = { gradParam };
+ upperGradients.Add(lastInst, grads);
+ for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst())
+ {
+ auto upperGrads = upperGradients.TryGetValue(child);
+ if (!upperGrads)
+ continue;
+ if (upperGrads->getCount() > 1)
+ {
+ auto sumGrad = upperGrads->getFirst();
+ for (auto i = 1; i < upperGrads->getCount(); i++)
+ {
+ sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]);
+ }
+ this->transcribeInstBackward(&subBuilder, child, sumGrad);
+ }
+ else
+ this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst());
+ }
+
+ subBuilder.emitReturn();
+
+ return InstPair(diffBlock, diffBlock);
+ }
+
+ // Create an empty func to represent the transcribed func of `origFunc`.
+ InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertBefore(origFunc);
+
+ IRFunc* primalFunc = origFunc;
+
+ differentiableTypeConformanceContext.setFunc(origFunc);
+
+ primalFunc = origFunc;
+
+ auto diffFunc = builder.createFunc();
+
+ SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType()));
+ IRType* diffFuncType = this->differentiateFunctionType(
+ &builder,
+ as<IRFuncType>(origFunc->getFullType()));
+ diffFunc->setFullType(diffFuncType);
+
+ if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>())
+ {
+ auto originalName = nameHint->getName();
+ StringBuilder newNameSb;
+ newNameSb << "s_bwd_" << originalName;
+ builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice());
+ }
+ builder.addBackwardDerivativeDecoration(origFunc, diffFunc);
+
+ // Mark the generated derivative function itself as differentiable.
+ builder.addBackwardDifferentiableDecoration(diffFunc);
+
+ // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc.
+ if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>())
+ {
+ cloneDecoration(dictDecor, diffFunc);
+ }
+
+ auto result = InstPair(primalFunc, diffFunc);
+ followUpFunctionsToTranscribe.add(result);
+ return result;
+ }
+
+ // Transcribe a function definition.
+ InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc)
+ {
+ IRBuilder builder(inBuilder->getSharedBuilder());
+ builder.setInsertInto(diffFunc);
+
+ differentiableTypeConformanceContext.setFunc(primalFunc);
+ // Transcribe children from origFunc into diffFunc
+ for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock())
+ this->transcribeBlock(&builder, block);
+
+ return InstPair(primalFunc, diffFunc);
+ }
+
+ IRInst* copyParam(IRBuilder* builder, IRParam* origParam)
+ {
+ auto primalDataType = origParam->getDataType();
+
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType);
+ IRInst* diffParam = builder->emitParam(inoutDiffPairType);
+
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice());
+
+ SLANG_ASSERT(diffParam);
+ auto paramValue = builder->emitLoad(diffParam);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ orginalToTranscribed.Add(origParam, primal);
+ primalToDiffPair.Add(primal, diffParam);
+
+ return diffParam;
+ }
+
+
+ return cloneInst(&cloneEnv, builder, origParam);
+ }
+
+ InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto origLeft = origArith->getOperand(0);
+ auto origRight = origArith->getOperand(1);
+
+ IRInst* primalLeft;
+ if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft))
+ {
+ primalLeft = origLeft;
+ }
+ IRInst* primalRight;
+ if (!orginalToTranscribed.TryGetValue(origRight, primalRight))
+ {
+ primalRight = origRight;
+ }
+
+ auto resultType = origArith->getDataType();
+ IRInst* newInst;
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ newInst = builder->emitAdd(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Mul:
+ newInst = builder->emitMul(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Sub:
+ newInst = builder->emitSub(resultType, primalLeft, primalRight);
+ break;
+ case kIROp_Div:
+ newInst = builder->emitDiv(resultType, primalLeft, primalRight);
+ break;
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+ orginalToTranscribed.Add(origArith, newInst);
+ return InstPair(newInst, nullptr);
+ }
+
+ IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad)
+ {
+ SLANG_ASSERT(origArith->getOperandCount() == 2);
+
+ auto lhs = origArith->getOperand(0);
+ auto rhs = origArith->getOperand(1);
+
+ if (as<IRInOutType>(lhs->getDataType()))
+ {
+ lhs = builder->emitLoad(lhs);
+ lhs = builder->emitDifferentialPairGetPrimal(lhs);
+ }
+ if (as<IRInOutType>(rhs->getDataType()))
+ {
+ rhs = builder->emitLoad(rhs);
+ rhs = builder->emitDifferentialPairGetPrimal(rhs);
+ }
+
+ IRInst* leftGrad;
+ IRInst* rightGrad;
+
+
+ switch (origArith->getOp())
+ {
+ case kIROp_Add:
+ leftGrad = grad;
+ rightGrad = grad;
+ break;
+ case kIROp_Mul:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad);
+ break;
+ case kIROp_Sub:
+ leftGrad = grad;
+ rightGrad = builder->emitNeg(grad->getDataType(), grad);
+ break;
+ case kIROp_Div:
+ leftGrad = builder->emitMul(grad->getDataType(), rhs, grad);
+ rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad
+ break;
+ default:
+ getSink()->diagnose(origArith->sourceLoc,
+ Diagnostics::unimplemented,
+ "this arithmetic instruction cannot be differentiated");
+ }
+
+ lhs = origArith->getOperand(0);
+ rhs = origArith->getOperand(1);
+ if (auto leftGrads = upperGradients.TryGetValue(lhs))
+ {
+ leftGrads->add(leftGrad);
+ }
+ else
+ {
+ upperGradients.Add(lhs, leftGrad);
+ }
+ if (auto rightGrads = upperGradients.TryGetValue(rhs))
+ {
+ rightGrads->add(rightGrad);
+ }
+ else
+ {
+ upperGradients.Add(rhs, rightGrad);
+ }
+
+ return nullptr;
+ }
+
+ InstPair copyInst(IRBuilder* builder, IRInst* origInst)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParam(builder, as<IRParam>(origInst));
+
+ case kIROp_Return:
+ return InstPair(nullptr, nullptr);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return copyBinaryArith(builder, origInst);
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return InstPair(nullptr, nullptr);
+ }
+
+ IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad)
+ {
+ IRInOutType* inoutParam = as<IRInOutType>(param->getDataType());
+ auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType());
+ auto paramValue = builder->emitLoad(param);
+ auto primal = builder->emitDifferentialPairGetPrimal(paramValue);
+ auto diff = builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ paramValue
+ );
+ auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad);
+ auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff);
+ auto store = builder->emitStore(param, updatedParam);
+
+ return store;
+ }
+
+ IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad)
+ {
+ // Handle common SSA-style operations
+ switch (origInst->getOp())
+ {
+ case kIROp_Param:
+ return transcribeParamBackward(builder, as<IRParam>(origInst), grad);
+
+ case kIROp_Add:
+ case kIROp_Mul:
+ case kIROp_Sub:
+ case kIROp_Div:
+ return transcribeBinaryArithBackward(builder, origInst, grad);
+
+ case kIROp_DifferentialPairGetPrimal:
+ {
+ if (auto param = primalToDiffPair.TryGetValue(origInst))
+ {
+ if (auto leftGrads = upperGradients.TryGetValue(*param))
+ {
+ leftGrads->add(grad);
+ }
+ else
+ {
+ upperGradients.Add(*param, grad);
+ }
+ }
+ else
+ SLANG_ASSERT(0);
+ return nullptr;
+ }
+
+ default:
+ // Not yet implemented
+ SLANG_ASSERT(0);
+ }
+
+ return nullptr;
+ }
+};
+
+
struct JVPDerivativeContext : public InstPassBase
{
@@ -2034,7 +2692,7 @@ struct JVPDerivativeContext : public InstPassBase
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->init(module);
sharedBuilder->deduplicateAndRebuildGlobalNumberingMap();
-
+
IRBuilder builderStorage(sharedBuilderStorage);
IRBuilder* builder = &builderStorage;
@@ -2059,7 +2717,7 @@ struct JVPDerivativeContext : public InstPassBase
IRInst* lookupJVPReference(IRInst* primalFunction)
{
- if(auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
+ if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>())
return jvpDefinition->getForwardDerivativeFunc();
return nullptr;
}
@@ -2069,16 +2727,24 @@ struct JVPDerivativeContext : public InstPassBase
//
bool processReferencedFunctions(IRBuilder* builder)
{
- List<IRForwardDifferentiate*> autoDiffWorkList;
+ List<IRInst*> autoDiffWorkList;
for (;;)
{
// Collect all `ForwardDifferentiate` insts from the module.
autoDiffWorkList.clear();
- processInstsOfType<IRForwardDifferentiate>(kIROp_ForwardDifferentiate, [&](IRForwardDifferentiate* fwdDiffInst)
- {
- autoDiffWorkList.add(fwdDiffInst);
- });
+ processAllInsts([&](IRInst* inst)
+ {
+ switch (inst->getOp())
+ {
+ case kIROp_ForwardDifferentiate:
+ case kIROp_BackwardDifferentiate:
+ autoDiffWorkList.add(inst);
+ break;
+ default:
+ break;
+ }
+ });
if (autoDiffWorkList.getCount() == 0)
break;
@@ -2086,42 +2752,59 @@ struct JVPDerivativeContext : public InstPassBase
// Process collected `ForwardDifferentiate` insts and replace them with placeholders for
// differentiated functions.
transcriberStorage.followUpFunctionsToTranscribe.clear();
+ backwardTranscriberStorage.followUpFunctionsToTranscribe.clear();
- for (auto fwdDiffInst : autoDiffWorkList)
+ for (auto differentiateInst : autoDiffWorkList)
{
- auto baseInst = fwdDiffInst->getBaseFn();
+ IRInst* baseInst = differentiateInst->getOperand(0);
+
if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst))
{
- if (auto existingDiffFunc = lookupJVPReference(baseFunction))
+ if (as<IRForwardDifferentiate>(differentiateInst))
{
- fwdDiffInst->replaceUsesWith(existingDiffFunc);
- fwdDiffInst->removeAndDeallocate();
- }
- else if (isMarkedForForwardDifferentiation(baseFunction))
- {
- if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ if (auto existingDiffFunc = lookupJVPReference(baseFunction))
{
- IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction);
- SLANG_ASSERT(diffFunc);
- fwdDiffInst->replaceUsesWith(diffFunc);
- fwdDiffInst->removeAndDeallocate();
+ differentiateInst->replaceUsesWith(existingDiffFunc);
+ differentiateInst->removeAndDeallocate();
}
- else
+ else if (isMarkedForForwardDifferentiation(baseFunction))
{
- // TODO(Sai): This would probably be better with a more specific
- // error code.
- getSink()->diagnose(fwdDiffInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Unexpected instruction. Expected func or generic");
+ if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ {
+ IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction);
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
}
}
- else
+ else if (as<IRBackwardDifferentiate>(differentiateInst))
{
- // TODO(Sai): This would probably be better with a more specific
- // error code.
- getSink()->diagnose(fwdDiffInst->sourceLoc,
- Diagnostics::internalCompilerError,
- "Cannot differentiate functions not marked for differentiation");
+ if (isMarkedForBackwardDifferentiation(baseFunction))
+ {
+ if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction))
+ {
+ IRInst* diffFunc =
+ backwardTranscriberStorage
+ .transcribeFuncHeader(builder, (IRFunc*)baseFunction)
+ .differential;
+ SLANG_ASSERT(diffFunc);
+ differentiateInst->replaceUsesWith(diffFunc);
+ differentiateInst->removeAndDeallocate();
+ }
+ else
+ {
+ getSink()->diagnose(differentiateInst->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "Unexpected instruction. Expected func or generic");
+ }
+ }
}
}
}
@@ -2136,6 +2819,16 @@ struct JVPDerivativeContext : public InstPassBase
transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
}
+ followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe);
+ for (auto task : followUpWorkList)
+ {
+ auto diffFunc = as<IRFunc>(task.differential);
+ SLANG_ASSERT(diffFunc);
+ auto primalFunc = as<IRFunc>(task.primal);
+ SLANG_ASSERT(primalFunc);
+
+ backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc);
+ }
// Transcribing the function body really shouldn't produce more follow up function body work.
// However it may produce new `ForwardDifferentiate` instructions, which we collect and process
@@ -2159,7 +2852,7 @@ struct JVPDerivativeContext : public InstPassBase
IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst)
{
-
+
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
bool isTrivial = false;
@@ -2182,7 +2875,7 @@ struct JVPDerivativeContext : public InstPassBase
return result;
}
}
-
+
return nullptr;
}
@@ -2213,7 +2906,7 @@ struct JVPDerivativeContext : public InstPassBase
return primalFieldExtract;
}
}
-
+
return nullptr;
}
@@ -2426,7 +3119,7 @@ struct JVPDerivativeContext : public InstPassBase
//
bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable)
{
- for(auto decoration = callable->getFirstDecoration();
+ for (auto decoration = callable->getFirstDecoration();
decoration;
decoration = decoration->getNextDecoration())
{
@@ -2438,11 +3131,11 @@ struct JVPDerivativeContext : public InstPassBase
return false;
}
- IRStringLit* getForwardDerivativeFuncName(IRInst* func)
+ IRStringLit* getForwardDerivativeFuncName(IRInst* func)
{
IRBuilder builder(&sharedBuilderStorage);
builder.setInsertBefore(func);
-
+
IRStringLit* name = nullptr;
if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
{
@@ -2456,17 +3149,54 @@ struct JVPDerivativeContext : public InstPassBase
return name;
}
+ // Checks decorators to see if the function should
+ // be differentiated (kIROp_ForwardDifferentiableDecoration)
+ //
+ bool isMarkedForBackwardDifferentiation(IRGlobalValueWithCode* callable)
+ {
+ for (auto decoration = callable->getFirstDecoration();
+ decoration;
+ decoration = decoration->getNextDecoration())
+ {
+ if (decoration->getOp() == kIROp_BackwardDifferentiableDecoration)
+ {
+ return true;
+ }
+ }
+ return false;
+ }
+
+ IRStringLit* getBackwardDerivativeFuncName(IRInst* func)
+ {
+ IRBuilder builder(&sharedBuilderStorage);
+ builder.setInsertBefore(func);
+
+ IRStringLit* name = nullptr;
+ if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>())
+ {
+ name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice());
+ }
+ else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>())
+ {
+ name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice());
+ }
+
+ return name;
+ }
+
JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) :
InstPassBase(module),
sink(sink),
autoDiffSharedContextStorage(module->getModuleInst()),
- transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage)
+ transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage),
+ backwardTranscriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage, sink)
{
autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage;
pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage;
transcriberStorage.sink = sink;
transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage);
transcriberStorage.pairBuilder = &(pairBuilderStorage);
+ backwardTranscriberStorage.pairBuilder = &pairBuilderStorage;
}
protected:
@@ -2474,10 +3204,12 @@ protected:
// processing instructions while maintaining state.
//
JVPTranscriber transcriberStorage;
-
+
+ BackwardDiffTranscriber backwardTranscriberStorage;
+
// Diagnostic object from the compile request for
// error messages.
- DiagnosticSink* sink;
+ DiagnosticSink* sink;
// Context to find and manage the witness tables for types
// implementing `IDifferentiable`
@@ -2490,11 +3222,11 @@ protected:
// Set up context and call main process method.
//
-bool processForwardDifferentiableFuncs(
- IRModule* module,
- DiagnosticSink* sink,
- IRJVPDerivativePassOptions const&)
-{
+bool processDifferentiableFuncs(
+ IRModule* module,
+ DiagnosticSink* sink,
+ IRJVPDerivativePassOptions const&)
+{
// Simplify module to remove dead code.
IRDeadCodeEliminationOptions options;
options.keepExportsAlive = true;
diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h
index 01ac15d6c..a866a3db3 100644
--- a/source/slang/slang-ir-diff-jvp.h
+++ b/source/slang/slang-ir-diff-jvp.h
@@ -13,7 +13,7 @@ namespace Slang
// Nothing for now..
};
- bool processForwardDifferentiableFuncs(
+ bool processDifferentiableFuncs(
IRModule* module,
DiagnosticSink* sink,
IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions());
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index cb4854d7d..51811f59e 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -707,6 +707,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// generated derivative function.
INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0)
+ /// Used by the auto-diff pass to hold a reference to the
+ /// generated derivative function.
+ INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0)
+
+ /// Used by the auto-diff pass to hold a reference to the
+ /// generated derivative function.
+ INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0)
+
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0)
@@ -763,6 +771,7 @@ INST(Reinterpret, reinterpret, 1, 0)
INST(CastPtrToBool, CastPtrToBool, 1, 0)
INST(IsType, IsType, 3, 0)
INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0)
+INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0)
INST(DifferentialEqualityTypeCast, DifferentialEqualityTypeCast, 1, 0)
// Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5587a7c68..5eea12de8 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -566,6 +566,16 @@ struct IRForwardDerivativeDecoration : IRDecoration
IRInst* getForwardDerivativeFunc() { return getOperand(0); }
};
+struct IRBackwardDifferentiableDecoration : IRDecoration
+{
+ enum
+ {
+ kOp = kIROp_BackwardDifferentiableDecoration
+ };
+ IR_LEAF_ISA(BackwardDifferentiableDecoration)
+};
+
+
struct IRDerivativeMemberDecoration : IRDecoration
{
enum
@@ -592,6 +602,21 @@ struct IRForwardDifferentiate : IRInst
IR_LEAF_ISA(ForwardDifferentiate)
};
+// An instruction that replaces the function symbol
+// with it's derivative function.
+struct IRBackwardDifferentiate : IRInst
+{
+ enum
+ {
+ kOp = kIROp_BackwardDifferentiate
+ };
+ // The base function for the call.
+ IRUse base;
+ IRInst* getBaseFn() { return getOperand(0); }
+
+ IR_LEAF_ISA(BackwardDifferentiate)
+};
+
// Dictionary item mapping a type with a corresponding
// IDifferentiable witness table
//
@@ -2497,6 +2522,7 @@ public:
IRInst* existentialValue);
IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn);
+ IRInst* emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn);
IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential);
@@ -2999,6 +3025,7 @@ public:
IRInst* emitBitAnd(IRType* type, IRInst* left, IRInst* right);
IRInst* emitBitOr(IRType* type, IRInst* left, IRInst* right);
IRInst* emitBitNot(IRType* type, IRInst* value);
+ IRInst* emitNeg(IRType* type, IRInst* value);
IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right);
IRInst* emitSub(IRType* type, IRInst* left, IRInst* right);
@@ -3207,11 +3234,21 @@ public:
addDecoration(value, kIROp_ForwardDifferentiableDecoration);
}
+ void addBackwardDifferentiableDecoration(IRInst* value)
+ {
+ addDecoration(value, kIROp_BackwardDifferentiableDecoration);
+ }
+
void addForwardDerivativeDecoration(IRInst* value, IRInst* fwdFunc)
{
addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc);
}
+ void addBackwardDerivativeDecoration(IRInst* value, IRInst* jvpFn)
+ {
+ addDecoration(value, kIROp_BackwardDerivativeDecoration, jvpFn);
+ }
+
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
{
addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1);
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 9d538a774..6112beaf2 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3109,6 +3109,17 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn)
+ {
+ auto inst = createInst<IRBackwardDifferentiate>(
+ this,
+ kIROp_BackwardDifferentiate,
+ type,
+ baseFn);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential)
{
IRInst* args[] = {primal, differential};
@@ -4556,6 +4567,17 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitNeg(IRType* type, IRInst* value)
+ {
+ auto inst = createInst<IRInst>(
+ this,
+ kIROp_Neg,
+ type,
+ value);
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitAdd(IRType* type, IRInst* left, IRInst* right)
{
auto inst = createInst<IRInst>(
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 5930875f1..a0158cf38 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3081,6 +3081,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
baseVal.val));
}
+ // Emit IR to denote the forward-mode derivative
+ // of the inner func-expr. This will be resolved
+ // to a concrete function during the derivative
+ // pass.
+ LoweredValInfo visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr)
+ {
+ auto baseVal = lowerSubExpr(expr->baseFunction);
+ SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple);
+
+ return LoweredValInfo::simple(
+ getBuilder()->emitBackwardDifferentiateInst(
+ lowerType(context, expr->type),
+ baseVal.val));
+ }
+
LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr)
{
auto baseVal = lowerSubExpr(expr->arrayExpr);
@@ -7799,6 +7814,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
{
getBuilder()->addForwardDifferentiableDecoration(irFunc);
}
+ if (decl->findModifier<BackwardDifferentiableAttribute>())
+ {
+ getBuilder()->addBackwardDifferentiableDecoration(irFunc);
+ }
if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>())
{
lowerDifferentiableAttribute(subContext, irFunc, differentialAttr);
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index 78edd4deb..d3dc5964e 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2109,6 +2109,26 @@ namespace Slang
return parseForwardDifferentiate(parser);
}
+ /// Parse an expression of the form __bwd_diff(fn) where fn is an
+ /// identifier pointing to a function.
+ static Expr* parseBackwardDifferentiate(Parser* parser)
+ {
+ BackwardDifferentiateExpr* bwdDiffExpr = parser->astBuilder->create<BackwardDifferentiateExpr>();
+
+ parser->ReadToken(TokenType::LParent);
+
+ bwdDiffExpr->baseFunction = parser->ParseExpression();
+
+ parser->ReadToken(TokenType::RParent);
+
+ return bwdDiffExpr;
+ }
+
+ static NodeBase* parseBackwardDifferentiate(Parser* parser, void* /* unused */)
+ {
+ return parseBackwardDifferentiate(parser);
+ }
+
/// Parse a `This` type expression
static Expr* parseThisTypeExpr(Parser* parser)
{
@@ -6646,7 +6666,8 @@ namespace Slang
_makeParseExpr("none", parseNoneExpr),
_makeParseExpr("try", parseTryExpr),
_makeParseExpr("__TaggedUnion", parseTaggedUnionType),
- _makeParseExpr("__fwd_diff", parseForwardDifferentiate)
+ _makeParseExpr("__fwd_diff", parseForwardDifferentiate),
+ _makeParseExpr("__bwd_diff", parseBackwardDifferentiate)
};
ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos()
diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis
index 38244f9eb..13334e00c 100644
--- a/source/slang/slang.natvis
+++ b/source/slang/slang.natvis
@@ -235,6 +235,7 @@
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialValueExpr">(Slang::ExtractExistentialValueExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OpenRefExpr">(Slang::OpenRefExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ForwardDifferentiateExpr">(Slang::ForwardDifferentiateExpr*)&amp;astNodeType</ExpandedItem>
+ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BackwardDifferentiateExpr">(Slang::BackwardDifferentiateExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::TaggedUnionTypeExpr">(Slang::TaggedUnionTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeExpr">(Slang::ThisTypeExpr*)&amp;astNodeType</ExpandedItem>
<ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&amp;astNodeType</ExpandedItem>
diff --git a/tests/autodiff/backward-diff-smoke.slang b/tests/autodiff/backward-diff-smoke.slang
new file mode 100644
index 000000000..3d0d8970d
--- /dev/null
+++ b/tests/autodiff/backward-diff-smoke.slang
@@ -0,0 +1,25 @@
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+typedef float.Differential dfloat;
+
+[BackwardDifferentiable]
+float test(float x, float y)
+{
+ return 2.0f * x + y * 4.0f - x;
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ dpfloat dpa = dpfloat(2.0, 0.0);
+ dpfloat dpb = dpfloat(1.5, 0.0);
+
+ __bwd_diff(test)(dpa, dpb, 1.0f);
+ outputBuffer[0] = dpa.d; // Expect: 1
+ outputBuffer[1] = dpb.d; // Expect: 4
+}
diff --git a/tests/autodiff/backward-diff-smoke.slang.expected.txt b/tests/autodiff/backward-diff-smoke.slang.expected.txt
new file mode 100644
index 000000000..8b514833c
--- /dev/null
+++ b/tests/autodiff/backward-diff-smoke.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+1.0
+4.0
+0.0
+0.0
+0.0 \ No newline at end of file