diff options
| -rw-r--r-- | source/slang/diff.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 59 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-overload.cpp | 55 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 824 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 37 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 22 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 19 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 23 | ||||
| -rw-r--r-- | source/slang/slang.natvis | 1 | ||||
| -rw-r--r-- | tests/autodiff/backward-diff-smoke.slang | 25 | ||||
| -rw-r--r-- | tests/autodiff/backward-diff-smoke.slang.expected.txt | 6 |
17 files changed, 1040 insertions, 67 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 1f6064983..6f1008277 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -9,6 +9,9 @@ attribute_syntax [ForwardDifferentiable] : ForwardDifferentiableAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; + __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index ef6a05c71..86b72e05a 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -451,6 +451,14 @@ class ForwardDifferentiateExpr: public HigherOrderInvokeExpr SLANG_AST_CLASS(ForwardDifferentiateExpr) }; + /// An expression of the form `__bwd_diff(fn)` to access the + /// backward-mode derivative version of the function `fn` + /// +class BackwardDifferentiateExpr: public HigherOrderInvokeExpr +{ + SLANG_AST_CLASS(BackwardDifferentiateExpr) +}; + /// A type expression of the form `__TaggedUnion(A, ...)`. /// /// An expression of this form will resolve to a `TaggedUnionType` diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index d6a961328..8419facce 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1045,6 +1045,12 @@ class ForwardDerivativeOfAttribute : public Attribute Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; + /// The `[BackwardDifferentiable]` attribute indicates that a function can be backward-differentiated. 
+class BackwardDifferentiableAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(BackwardDifferentiableAttribute) +}; + /// Indicates that the modified declaration is one of the "magic" declarations /// that NVAPI uses to communicate extended operations. When NVAPI is being included /// via the prelude for downstream compilation, declarations with this modifier diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 30db9ecfa..f568dd8df 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -969,6 +969,14 @@ namespace Slang maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && + this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + } + return checkedTerm; } @@ -2027,6 +2035,45 @@ namespace Slang return jvpType; } + Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType) + { + // Resolve backward diff type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-diff-jvp.cpp + + FuncType* type = m_astBuilder->create<FuncType>(); + + // The backward diff return type is void + // + type->resultType = m_astBuilder->getVoidType(); + + // No support for differentiating functions that throw errors, for now. 
+ SLANG_ASSERT(originalType->errorType->equals(m_astBuilder->getBottomType())); + type->errorType = originalType->errorType; + + for (UInt i = 0; i < originalType->getParamCount(); i++) + { + if (auto derivType = _toDifferentialParamType(originalType->getParamType(i))) + { + // Using inout type on all the derivative parameters + if (auto outType = as<OutType>(derivType)) + { + derivType = outType->getValueType(); + } + else if (!as<PtrTypeBase>(derivType)) + { + derivType = m_astBuilder->getInOutType(derivType); + } + type->paramTypes.add(derivType); + } + } + + // Last parameter is the initial derivative of the original return type + type->paramTypes.add(originalType->resultType); + + return type; + } + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) { // Check/Resolve inner function declaration. @@ -2039,6 +2086,18 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + // Check/Resolve inner function declaration. + expr->baseFunction = CheckTerm(expr->baseFunction); + + // For now we only support using higher order expr as callee in an invoke expr. + // The actual type of the higher order function will be derived during resolve invoke. + expr->type = m_astBuilder->getBottomType(); + + return expr; + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 70b120518..e7681212f 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -714,6 +714,7 @@ namespace Slang // Convert a function's original type to it's JVP type. Type* processJVPFuncType(FuncType* originalType); + Type* processBackwardDiffFuncType(FuncType* originalType); /// Registers a type as conforming to IDifferentiable, along with a witness /// describing the relationship. 
@@ -1908,6 +1909,7 @@ namespace Slang Expr* visitModifiedTypeExpr(ModifiedTypeExpr* expr); Expr* visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr); + Expr* visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr); Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index 38754d170..fe9de9433 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -1548,35 +1548,51 @@ namespace Slang // Lookup the higher order function and process types accordingly. In the future, // if there are enough varieties, we can have dispatch logic instead of an // if-else ladder. - if (auto jvpExpr = as<ForwardDifferentiateExpr>(funcExpr)) + if (auto expr = as<HigherOrderInvokeExpr>(funcExpr)) { - if (auto origFuncType = as<FuncType>(jvpExpr->baseFunction->type)) + if (auto origFuncType = as<FuncType>(expr->baseFunction->type)) { - // Case: __fwd_diff(name-resolved-to-decl-ref) - auto baseFuncDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<CallableDecl>(); + auto baseFuncDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<CallableDecl>(); SLANG_ASSERT(baseFuncDeclRef); OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType)); + if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) + { + // Case: __fwd_diff(name-resolved-to-decl-ref) + candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType)); + } + else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + { + // Case: __bwd_diff(name-resolved-to-decl-ref) + candidate.funcType = as<FuncType>(processBackwardDiffFuncType(origFuncType)); + } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); AddOverloadCandidate(context, candidate); } - else if (auto origOverloadedType = 
as<OverloadGroupType>(jvpExpr->baseFunction->type)) + else if (auto origOverloadedType = as<OverloadGroupType>(expr->baseFunction->type)) { - // Case: __fwd_diff(name-resolved-to-multiple-decl-ref) - if (auto overloadExpr = as<OverloadedExpr>(jvpExpr->baseFunction)) + if (auto overloadExpr = as<OverloadedExpr>(expr->baseFunction)) { for (auto item : overloadExpr->lookupResult2.items) { auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)); if (!funcType) continue; - funcType = as<FuncType>(processJVPFuncType(funcType)); + if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) + { + // Case: __fwd_diff(name-resolved-to-multiple-decl-ref) + funcType = as<FuncType>(processJVPFuncType(funcType)); + } + else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + { + // Case: __bwd_diff(name-resolved-to-multiple-decl-ref) + funcType = as<FuncType>(processBackwardDiffFuncType(funcType)); + } if (!funcType) continue; OverloadCandidate candidate; @@ -1597,9 +1613,8 @@ namespace Slang funcExpr->type); } } - else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(jvpExpr->baseFunction)->declRef.as<GenericDecl>()) + else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<GenericDecl>()) { - // Case: __fwd_diff(name-resolved-to-generic-decl) // Get inner function DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( @@ -1610,7 +1625,9 @@ namespace Slang auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>()); // Process func type to generate JVP func type. - auto jvpFuncType = as<FuncType>(processJVPFuncType(funcType)); + auto jvpFuncType = as<ForwardDifferentiateExpr>(expr) ? + as<FuncType>(processJVPFuncType(funcType)) : + as<FuncType>(processBackwardDiffFuncType(funcType)); // Extract parameter list from processed type. List<Type*> paramTypes; @@ -1634,8 +1651,18 @@ namespace Slang // in order to process the specialized version of the original func type. 
// This could potentially be a declRef.substitute(jvpFuncType) // - candidate.funcType = as<FuncType>(processJVPFuncType( - getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) + { + // Case: __fwd_diff(name-resolved-to-generic-decl) + candidate.funcType = as<FuncType>(processJVPFuncType( + getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + } + else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + { + // Case: __bwd_diff(name-resolved-to-generic-decl) + candidate.funcType = as<FuncType>(processBackwardDiffFuncType( + getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); + } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(innerRef); diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index de9d23e97..2478eccc2 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -376,11 +376,7 @@ Result linkAndOptimizeIR( dumpIRIfEnabled(codeGenContext, irModule, "BEFORE-AUTODIFF"); // Process higher-order calles to auto-diff passes. - // 1. Generate JVP code wherever necessary. (Linearization or "forward-mode" pass) - processForwardDifferentiableFuncs(irModule, sink); - - // 2. Transpose JVP to VJP code wherever needed. (Transposition or "reverse-mode" pass) - // processVJPDerivativeMarkers(module); // Disabled currently. No impl yet. + processDifferentiableFuncs(irModule, sink); stripAutoDiffDecorations(irModule); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 4c7a132d0..7f5979a87 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -2017,6 +2017,664 @@ struct JVPTranscriber } }; + +struct BackwardDiffTranscriber +{ + + // Stores the mapping of arbitrary 'R-value' instructions to instructions that represent + // their differential values. 
+ Dictionary<IRInst*, IRInst*> orginalToTranscribed; + + // Set of insts currently being transcribed. Used to avoid infinite loops. + HashSet<IRInst*> instsInProgress; + + // Cloning environment to hold mapping from old to new copies for the primal + // instructions. + IRCloneEnv cloneEnv; + + // Diagnostic sink for error messages. + DiagnosticSink* sink; + + // Type conformance information. + AutoDiffSharedContext* autoDiffSharedContext; + + // Builder to help with creating and accessing the 'DifferentiablePair<T>' struct + DifferentialPairTypeBuilder* pairBuilder; + + DifferentiableTypeConformanceContext differentiableTypeConformanceContext; + + List<InstPair> followUpFunctionsToTranscribe; + + // Map that stores the upper gradient given an IRInst* + Dictionary<IRInst*, List<IRInst*>> upperGradients; + Dictionary<IRInst*, IRInst*> primalToDiffPair; + + SharedIRBuilder* sharedBuilder; + // Witness table that `DifferentialBottom:IDifferential`. + IRWitnessTable* differentialBottomWitness = nullptr; + Dictionary<InstPair, IRInst*> differentialPairTypes; + + BackwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) + : autoDiffSharedContext(shared) + , sink(inSink) + , differentiableTypeConformanceContext(shared) + , sharedBuilder(inSharedBuilder) + {} + + DiagnosticSink* getSink() + { + SLANG_ASSERT(sink); + return sink; + } + + IRFuncType* differentiateFunctionType(IRBuilder* builder, IRFuncType* funcType) + { + List<IRType*> newParameterTypes; + IRType* diffReturnType; + + for (UIndex i = 0; i < funcType->getParamCount(); i++) + { + auto origType = funcType->getParamType(i); + if (auto diffPairType = tryGetDiffPairType(builder, origType)) + { + auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); + newParameterTypes.add(inoutDiffPairType); + } + else + newParameterTypes.add(origType); + } + + newParameterTypes.add(funcType->getResultType()); + + diffReturnType = builder->getVoidType(); 
+ + return builder->getFuncType(newParameterTypes, diffReturnType); + } + + IRWitnessTable* getDifferentialBottomWitness() + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(sharedBuilder->getModule()->getModuleInst()); + auto result = + as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + SLANG_ASSERT(result); + return result; + } + + // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. + IRWitnessTable* getDifferentialPairWitness(IRInst* inDiffPairType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(inDiffPairType->parent); + auto diffPairType = as<IRDifferentialPairType>(inDiffPairType); + SLANG_ASSERT(diffPairType); + auto result = + as<IRWitnessTable>(differentiableTypeConformanceContext.lookUpConformanceForType( + builder.getDifferentialBottomType())); + if (result) + return result; + + auto table = builder.createWitnessTable(autoDiffSharedContext->differentiableInterfaceType, diffPairType); + auto diffType = differentiateType(&builder, diffPairType->getValueType()); + auto differentialType = builder.getDifferentialPairType(diffType, getDifferentialBottomWitness()); + builder.createWitnessTableEntry(table, autoDiffSharedContext->differentialAssocTypeStructKey, differentialType); + // Omit the method synthesis here, since we can just intercept those directly at `getXXMethodForType`. 
+ + differentiableTypeConformanceContext.differentiableWitnessDictionary[diffPairType] = table; + return table; + } + + IRType* getOrCreateDiffPairType(IRInst* primalType, IRInst* witness) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* getOrCreateDiffPairType(IRInst* primalType) + { + IRBuilder builder(sharedBuilder); + builder.setInsertInto(primalType->parent); + auto witness = as<IRWitnessTable>( + differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); + if (!witness) + witness = getDifferentialBottomWitness(); + return builder.getDifferentialPairType( + (IRType*)primalType, + witness); + } + + IRType* differentiateType(IRBuilder* builder, IRType* origType) + { + IRInst* diffType = nullptr; + if (!orginalToTranscribed.TryGetValue(origType, diffType)) + { + diffType = _differentiateTypeImpl(builder, origType); + orginalToTranscribed[origType] = diffType; + } + return (IRType*)diffType; + } + + IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType) + { + if (auto ptrType = as<IRPtrTypeBase>(origType)) + return builder->getPtrType( + origType->getOp(), + differentiateType(builder, ptrType->getValueType())); + + // If there is an explicit primal version of this type in the local scope, load that + // otherwise use the original type. + // + IRInst* primalType = origType; + + // Special case certain compound types (PtrType, FuncType, etc..) + // otherwise try to lookup a differential definition for the given type. + // If one does not exist, then we assume it's not differentiable. 
+ // + switch (primalType->getOp()) + { + case kIROp_Param: + if (as<IRTypeType>(primalType->getDataType())) + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType( + builder, + (IRType*)primalType)); + else if (as<IRWitnessTableType>(primalType->getDataType())) + return (IRType*)primalType; + + case kIROp_ArrayType: + { + auto primalArrayType = as<IRArrayType>(primalType); + if (auto diffElementType = differentiateType(builder, primalArrayType->getElementType())) + return builder->getArrayType( + diffElementType, + primalArrayType->getElementCount()); + else + return nullptr; + } + + case kIROp_DifferentialPairType: + { + auto primalPairType = as<IRDifferentialPairType>(primalType); + return getOrCreateDiffPairType( + pairBuilder->getDiffTypeFromPairType(builder, primalPairType), + pairBuilder->getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + + case kIROp_FuncType: + return differentiateFunctionType(builder, as<IRFuncType>(primalType)); + + case kIROp_OutType: + if (auto diffValueType = differentiateType(builder, as<IROutType>(primalType)->getValueType())) + return builder->getOutType(diffValueType); + else + return nullptr; + + case kIROp_InOutType: + if (auto diffValueType = differentiateType(builder, as<IRInOutType>(primalType)->getValueType())) + return builder->getInOutType(diffValueType); + else + return nullptr; + + case kIROp_TupleType: + { + auto tupleType = as<IRTupleType>(primalType); + List<IRType*> diffTypeList; + // TODO: what if we have type parameters here? 
+ for (UIndex ii = 0; ii < tupleType->getOperandCount(); ii++) + diffTypeList.add( + differentiateType(builder, (IRType*)tupleType->getOperand(ii))); + + return builder->getTupleType(diffTypeList); + } + + default: + return (IRType*)(differentiableTypeConformanceContext.getDifferentialForType(builder, (IRType*)primalType)); + } + } + + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType) + { + // If this is a PtrType (out, inout, etc..), then create diff pair from + // value type and re-apply the appropriate PtrType wrapper. + // + if (auto origPtrType = as<IRPtrTypeBase>(primalType)) + { + if (auto diffPairValueType = tryGetDiffPairType(builder, origPtrType->getValueType())) + return builder->getPtrType(primalType->getOp(), diffPairValueType); + else + return nullptr; + } + auto diffType = differentiateType(builder, primalType); + if (diffType) + return (IRType*)getOrCreateDiffPairType(primalType); + return nullptr; + } + + InstPair transcribeParam(IRBuilder* builder, IRParam* origParam) + { + auto primalDataType = origParam->getDataType(); + // Do not differentiate generic type (and witness table) parameters + if (as<IRTypeType>(primalDataType) || as<IRWitnessTableType>(primalDataType)) + { + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffPairParam); + + if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout 
DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); + } + + + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + + // Returns "dp<var-name>" to use as a name hint for parameters. + // If no primal name is available, returns a blank string. + // + String makeDiffPairName(IRInst* origVar) + { + if (auto namehintDecoration = origVar->findDecoration<IRNameHintDecoration>()) + { + return ("dp" + String(namehintDecoration->getName())); + } + + return String(""); + } + + + // In differential computation, the 'default' differential value is always zero. + // This is a consequence of differential computing being inherently linear. As a + // result, it's useful to have a method to generate zero literals of any (arithmetic) type. + // The current implementation requires that types are defined linearly. + // + IRInst* getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) + { + if (auto diffType = differentiateType(builder, primalType)) + { + switch (diffType->getOp()) + { + case kIROp_DifferentialPairType: + return builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as<IRDifferentialPairType>(diffType)->getValueType())); + } + // Since primalType has a corresponding differential type, we can lookup the + // definition for zero(). 
+ auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); + SLANG_ASSERT(zeroMethod); + + auto emptyArgList = List<IRInst*>(); + return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + } + else + { + if (isScalarIntegerType(primalType)) + { + return builder->getIntValue(primalType, 0); + } + + getSink()->diagnose(primalType->sourceLoc, + Diagnostics::internalCompilerError, + "could not generate zero value for given type"); + return nullptr; + } + } + + InstPair transcribeBlock(IRBuilder* builder, IRBlock* origBlock) + { + IRBuilder subBuilder(builder->getSharedBuilder()); + subBuilder.setInsertLoc(builder->getInsertLoc()); + + IRBlock* diffBlock = subBuilder.emitBlock(); + + subBuilder.setInsertInto(diffBlock); + + // First transcribe every parameter in the block. + for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) + this->copyParam(&subBuilder, param); + + // The extra param for input gradient + auto gradParam = subBuilder.emitParam(as<IRFuncType>(origBlock->getParent()->getFullType())->getResultType()); + + // Then, run through every instruction and use the transcriber to generate the appropriate + // derivative code. 
+ // + for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + this->copyInst(&subBuilder, child); + + auto lastInst = diffBlock->getLastOrdinaryInst(); + List<IRInst*> grads = { gradParam }; + upperGradients.Add(lastInst, grads); + for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst()) + { + auto upperGrads = upperGradients.TryGetValue(child); + if (!upperGrads) + continue; + if (upperGrads->getCount() > 1) + { + auto sumGrad = upperGrads->getFirst(); + for (auto i = 1; i < upperGrads->getCount(); i++) + { + sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); + } + this->transcribeInstBackward(&subBuilder, child, sumGrad); + } + else + this->transcribeInstBackward(&subBuilder, child, upperGrads->getFirst()); + } + + subBuilder.emitReturn(); + + return InstPair(diffBlock, diffBlock); + } + + // Create an empty func to represent the transcribed func of `origFunc`. + InstPair transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertBefore(origFunc); + + IRFunc* primalFunc = origFunc; + + differentiableTypeConformanceContext.setFunc(origFunc); + + primalFunc = origFunc; + + auto diffFunc = builder.createFunc(); + + SLANG_ASSERT(as<IRFuncType>(origFunc->getFullType())); + IRType* diffFuncType = this->differentiateFunctionType( + &builder, + as<IRFuncType>(origFunc->getFullType())); + diffFunc->setFullType(diffFuncType); + + if (auto nameHint = origFunc->findDecoration<IRNameHintDecoration>()) + { + auto originalName = nameHint->getName(); + StringBuilder newNameSb; + newNameSb << "s_bwd_" << originalName; + builder.addNameHintDecoration(diffFunc, newNameSb.getUnownedSlice()); + } + builder.addBackwardDerivativeDecoration(origFunc, diffFunc); + + // Mark the generated derivative function itself as differentiable. 
+ builder.addBackwardDifferentiableDecoration(diffFunc); + + // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. + if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) + { + cloneDecoration(dictDecor, diffFunc); + } + + auto result = InstPair(primalFunc, diffFunc); + followUpFunctionsToTranscribe.add(result); + return result; + } + + // Transcribe a function definition. + InstPair transcribeFunc(IRBuilder* inBuilder, IRFunc* primalFunc, IRFunc* diffFunc) + { + IRBuilder builder(inBuilder->getSharedBuilder()); + builder.setInsertInto(diffFunc); + + differentiableTypeConformanceContext.setFunc(primalFunc); + // Transcribe children from origFunc into diffFunc + for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribeBlock(&builder, block); + + return InstPair(primalFunc, diffFunc); + } + + IRInst* copyParam(IRBuilder* builder, IRParam* origParam) + { + auto primalDataType = origParam->getDataType(); + + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); + IRInst* diffParam = builder->emitParam(inoutDiffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffParam); + auto paramValue = builder->emitLoad(diffParam); + auto primal = builder->emitDifferentialPairGetPrimal(paramValue); + orginalToTranscribed.Add(origParam, primal); + primalToDiffPair.Add(primal, diffParam); + + return diffParam; + } + + + return cloneInst(&cloneEnv, builder, origParam); + } + + InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith) + { + SLANG_ASSERT(origArith->getOperandCount() == 2); + + auto origLeft = origArith->getOperand(0); + auto origRight = origArith->getOperand(1); + + IRInst* primalLeft; + if 
(!orginalToTranscribed.TryGetValue(origLeft, primalLeft)) + { + primalLeft = origLeft; + } + IRInst* primalRight; + if (!orginalToTranscribed.TryGetValue(origRight, primalRight)) + { + primalRight = origRight; + } + + auto resultType = origArith->getDataType(); + IRInst* newInst; + switch (origArith->getOp()) + { + case kIROp_Add: + newInst = builder->emitAdd(resultType, primalLeft, primalRight); + break; + case kIROp_Mul: + newInst = builder->emitMul(resultType, primalLeft, primalRight); + break; + case kIROp_Sub: + newInst = builder->emitSub(resultType, primalLeft, primalRight); + break; + case kIROp_Div: + newInst = builder->emitDiv(resultType, primalLeft, primalRight); + break; + default: + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); + } + orginalToTranscribed.Add(origArith, newInst); + return InstPair(newInst, nullptr); + } + + IRInst* transcribeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) + { + SLANG_ASSERT(origArith->getOperandCount() == 2); + + auto lhs = origArith->getOperand(0); + auto rhs = origArith->getOperand(1); + + if (as<IRInOutType>(lhs->getDataType())) + { + lhs = builder->emitLoad(lhs); + lhs = builder->emitDifferentialPairGetPrimal(lhs); + } + if (as<IRInOutType>(rhs->getDataType())) + { + rhs = builder->emitLoad(rhs); + rhs = builder->emitDifferentialPairGetPrimal(rhs); + } + + IRInst* leftGrad; + IRInst* rightGrad; + + + switch (origArith->getOp()) + { + case kIROp_Add: + leftGrad = grad; + rightGrad = grad; + break; + case kIROp_Mul: + leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); + rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); + break; + case kIROp_Sub: + leftGrad = grad; + rightGrad = builder->emitNeg(grad->getDataType(), grad); + break; + case kIROp_Div: + leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); + rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / 
Grad + break; + default: + getSink()->diagnose(origArith->sourceLoc, + Diagnostics::unimplemented, + "this arithmetic instruction cannot be differentiated"); + } + + lhs = origArith->getOperand(0); + rhs = origArith->getOperand(1); + if (auto leftGrads = upperGradients.TryGetValue(lhs)) + { + leftGrads->add(leftGrad); + } + else + { + upperGradients.Add(lhs, leftGrad); + } + if (auto rightGrads = upperGradients.TryGetValue(rhs)) + { + rightGrads->add(rightGrad); + } + else + { + upperGradients.Add(rhs, rightGrad); + } + + return nullptr; + } + + InstPair copyInst(IRBuilder* builder, IRInst* origInst) + { + // Handle common SSA-style operations + switch (origInst->getOp()) + { + case kIROp_Param: + return transcribeParam(builder, as<IRParam>(origInst)); + + case kIROp_Return: + return InstPair(nullptr, nullptr); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return copyBinaryArith(builder, origInst); + + default: + // Not yet implemented + SLANG_ASSERT(0); + } + + return InstPair(nullptr, nullptr); + } + + IRInst* transcribeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) + { + IRInOutType* inoutParam = as<IRInOutType>(param->getDataType()); + auto pairType = as<IRDifferentialPairType>(inoutParam->getValueType()); + auto paramValue = builder->emitLoad(param); + auto primal = builder->emitDifferentialPairGetPrimal(paramValue); + auto diff = builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + paramValue + ); + auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad); + auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff); + auto store = builder->emitStore(param, updatedParam); + + return store; + } + + IRInst* transcribeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) + { + // Handle common SSA-style operations + switch (origInst->getOp()) + { + case kIROp_Param: + return transcribeParamBackward(builder, 
as<IRParam>(origInst), grad); + + case kIROp_Add: + case kIROp_Mul: + case kIROp_Sub: + case kIROp_Div: + return transcribeBinaryArithBackward(builder, origInst, grad); + + case kIROp_DifferentialPairGetPrimal: + { + if (auto param = primalToDiffPair.TryGetValue(origInst)) + { + if (auto leftGrads = upperGradients.TryGetValue(*param)) + { + leftGrads->add(grad); + } + else + { + upperGradients.Add(*param, grad); + } + } + else + SLANG_ASSERT(0); + return nullptr; + } + + default: + // Not yet implemented + SLANG_ASSERT(0); + } + + return nullptr; + } +}; + + struct JVPDerivativeContext : public InstPassBase { @@ -2034,7 +2692,7 @@ struct JVPDerivativeContext : public InstPassBase SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->init(module); sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); - + IRBuilder builderStorage(sharedBuilderStorage); IRBuilder* builder = &builderStorage; @@ -2059,7 +2717,7 @@ struct JVPDerivativeContext : public InstPassBase IRInst* lookupJVPReference(IRInst* primalFunction) { - if(auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) + if (auto jvpDefinition = primalFunction->findDecoration<IRForwardDerivativeDecoration>()) return jvpDefinition->getForwardDerivativeFunc(); return nullptr; } @@ -2069,16 +2727,24 @@ struct JVPDerivativeContext : public InstPassBase // bool processReferencedFunctions(IRBuilder* builder) { - List<IRForwardDifferentiate*> autoDiffWorkList; + List<IRInst*> autoDiffWorkList; for (;;) { // Collect all `ForwardDifferentiate` insts from the module. 
autoDiffWorkList.clear(); - processInstsOfType<IRForwardDifferentiate>(kIROp_ForwardDifferentiate, [&](IRForwardDifferentiate* fwdDiffInst) - { - autoDiffWorkList.add(fwdDiffInst); - }); + processAllInsts([&](IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_ForwardDifferentiate: + case kIROp_BackwardDifferentiate: + autoDiffWorkList.add(inst); + break; + default: + break; + } + }); if (autoDiffWorkList.getCount() == 0) break; @@ -2086,42 +2752,59 @@ struct JVPDerivativeContext : public InstPassBase // Process collected `ForwardDifferentiate` insts and replace them with placeholders for // differentiated functions. transcriberStorage.followUpFunctionsToTranscribe.clear(); + backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); - for (auto fwdDiffInst : autoDiffWorkList) + for (auto differentiateInst : autoDiffWorkList) { - auto baseInst = fwdDiffInst->getBaseFn(); + IRInst* baseInst = differentiateInst->getOperand(0); + if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst)) { - if (auto existingDiffFunc = lookupJVPReference(baseFunction)) + if (as<IRForwardDifferentiate>(differentiateInst)) { - fwdDiffInst->replaceUsesWith(existingDiffFunc); - fwdDiffInst->removeAndDeallocate(); - } - else if (isMarkedForForwardDifferentiation(baseFunction)) - { - if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) + if (auto existingDiffFunc = lookupJVPReference(baseFunction)) { - IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction); - SLANG_ASSERT(diffFunc); - fwdDiffInst->replaceUsesWith(diffFunc); - fwdDiffInst->removeAndDeallocate(); + differentiateInst->replaceUsesWith(existingDiffFunc); + differentiateInst->removeAndDeallocate(); } - else + else if (isMarkedForForwardDifferentiation(baseFunction)) { - // TODO(Sai): This would probably be better with a more specific - // error code. - getSink()->diagnose(fwdDiffInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. 
Expected func or generic"); + if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) + { + IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction); + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); + } } } - else + else if (as<IRBackwardDifferentiate>(differentiateInst)) { - // TODO(Sai): This would probably be better with a more specific - // error code. - getSink()->diagnose(fwdDiffInst->sourceLoc, - Diagnostics::internalCompilerError, - "Cannot differentiate functions not marked for differentiation"); + if (isMarkedForBackwardDifferentiation(baseFunction)) + { + if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) + { + IRInst* diffFunc = + backwardTranscriberStorage + .transcribeFuncHeader(builder, (IRFunc*)baseFunction) + .differential; + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); + } + } } } } @@ -2136,6 +2819,16 @@ struct JVPDerivativeContext : public InstPassBase transcriberStorage.transcribeFunc(builder, primalFunc, diffFunc); } + followUpWorkList = _Move(backwardTranscriberStorage.followUpFunctionsToTranscribe); + for (auto task : followUpWorkList) + { + auto diffFunc = as<IRFunc>(task.differential); + SLANG_ASSERT(diffFunc); + auto primalFunc = as<IRFunc>(task.primal); + SLANG_ASSERT(primalFunc); + + backwardTranscriberStorage.transcribeFunc(builder, primalFunc, diffFunc); + } // Transcribing the function body really shouldn't produce more follow up function body work. 
// However it may produce new `ForwardDifferentiate` instructions, which we collect and process @@ -2159,7 +2852,7 @@ struct JVPDerivativeContext : public InstPassBase IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) { - + if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { bool isTrivial = false; @@ -2182,7 +2875,7 @@ struct JVPDerivativeContext : public InstPassBase return result; } } - + return nullptr; } @@ -2213,7 +2906,7 @@ struct JVPDerivativeContext : public InstPassBase return primalFieldExtract; } } - + return nullptr; } @@ -2426,7 +3119,7 @@ struct JVPDerivativeContext : public InstPassBase // bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable) { - for(auto decoration = callable->getFirstDecoration(); + for (auto decoration = callable->getFirstDecoration(); decoration; decoration = decoration->getNextDecoration()) { @@ -2438,11 +3131,11 @@ struct JVPDerivativeContext : public InstPassBase return false; } - IRStringLit* getForwardDerivativeFuncName(IRInst* func) + IRStringLit* getForwardDerivativeFuncName(IRInst* func) { IRBuilder builder(&sharedBuilderStorage); builder.setInsertBefore(func); - + IRStringLit* name = nullptr; if (auto linkageDecoration = func->findDecoration<IRLinkageDecoration>()) { @@ -2456,17 +3149,54 @@ struct JVPDerivativeContext : public InstPassBase return name; } + // Checks decorators to see if the function should + // be backward-differentiated (kIROp_BackwardDifferentiableDecoration) + // + bool isMarkedForBackwardDifferentiation(IRGlobalValueWithCode* callable) + { + for (auto decoration = callable->getFirstDecoration(); + decoration; + decoration = decoration->getNextDecoration()) + { + if (decoration->getOp() == kIROp_BackwardDifferentiableDecoration) + { + return true; + } + } + return false; + } + + IRStringLit* getBackwardDerivativeFuncName(IRInst* func) + { + IRBuilder builder(&sharedBuilderStorage); + builder.setInsertBefore(func); + + IRStringLit* name = nullptr; + if (auto linkageDecoration = 
func->findDecoration<IRLinkageDecoration>()) + { + name = builder.getStringValue((String(linkageDecoration->getMangledName()) + "_bwd_diff").getUnownedSlice()); + } + else if (auto namehintDecoration = func->findDecoration<IRNameHintDecoration>()) + { + name = builder.getStringValue((String(namehintDecoration->getName()) + "_bwd_diff").getUnownedSlice()); + } + + return name; + } + JVPDerivativeContext(IRModule* module, DiagnosticSink* sink) : InstPassBase(module), sink(sink), autoDiffSharedContextStorage(module->getModuleInst()), - transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage) + transcriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage), + backwardTranscriberStorage(&autoDiffSharedContextStorage, &sharedBuilderStorage, sink) { autoDiffSharedContextStorage.sharedBuilder = &sharedBuilderStorage; pairBuilderStorage.sharedContext = &autoDiffSharedContextStorage; transcriberStorage.sink = sink; transcriberStorage.autoDiffSharedContext = &(autoDiffSharedContextStorage); transcriberStorage.pairBuilder = &(pairBuilderStorage); + backwardTranscriberStorage.pairBuilder = &pairBuilderStorage; } protected: @@ -2474,10 +3204,12 @@ protected: // processing instructions while maintaining state. // JVPTranscriber transcriberStorage; - + + BackwardDiffTranscriber backwardTranscriberStorage; + // Diagnostic object from the compile request for // error messages. - DiagnosticSink* sink; + DiagnosticSink* sink; // Context to find and manage the witness tables for types // implementing `IDifferentiable` @@ -2490,11 +3222,11 @@ protected: // Set up context and call main process method. // -bool processForwardDifferentiableFuncs( - IRModule* module, - DiagnosticSink* sink, - IRJVPDerivativePassOptions const&) -{ +bool processDifferentiableFuncs( + IRModule* module, + DiagnosticSink* sink, + IRJVPDerivativePassOptions const&) +{ // Simplify module to remove dead code. 
IRDeadCodeEliminationOptions options; options.keepExportsAlive = true; diff --git a/source/slang/slang-ir-diff-jvp.h b/source/slang/slang-ir-diff-jvp.h index 01ac15d6c..a866a3db3 100644 --- a/source/slang/slang-ir-diff-jvp.h +++ b/source/slang/slang-ir-diff-jvp.h @@ -13,7 +13,7 @@ namespace Slang // Nothing for now.. }; - bool processForwardDifferentiableFuncs( + bool processDifferentiableFuncs( IRModule* module, DiagnosticSink* sink, IRJVPDerivativePassOptions const& options = IRJVPDerivativePassOptions()); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index cb4854d7d..51811f59e 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -707,6 +707,14 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// generated derivative function. INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0) + /// Used by the auto-diff pass to mark a function as + /// backward-differentiable. + INST(BackwardDifferentiableDecoration, backwardDifferentiable, 1, 0) + + /// Used by the auto-diff pass to hold a reference to the + /// generated backward-mode derivative function. + INST(BackwardDerivativeDecoration, backwardDiffReference, 1, 0) + /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. 
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) @@ -763,6 +771,7 @@ INST(Reinterpret, reinterpret, 1, 0) INST(CastPtrToBool, CastPtrToBool, 1, 0) INST(IsType, IsType, 3, 0) INST(ForwardDifferentiate, ForwardDifferentiate, 1, 0) +INST(BackwardDifferentiate, BackwardDifferentiate, 1, 0) INST(DifferentialEqualityTypeCast, DifferentialEqualityTypeCast, 1, 0) // Converts other resources (such as ByteAddressBuffer) to the equivalent StructuredBuffer diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5587a7c68..5eea12de8 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -566,6 +566,16 @@ struct IRForwardDerivativeDecoration : IRDecoration IRInst* getForwardDerivativeFunc() { return getOperand(0); } }; +struct IRBackwardDifferentiableDecoration : IRDecoration +{ + enum + { + kOp = kIROp_BackwardDifferentiableDecoration + }; + IR_LEAF_ISA(BackwardDifferentiableDecoration) +}; + + struct IRDerivativeMemberDecoration : IRDecoration { enum @@ -592,6 +602,21 @@ struct IRForwardDifferentiate : IRInst IR_LEAF_ISA(ForwardDifferentiate) }; +// An instruction that replaces the function symbol +// with its backward-mode derivative function. +struct IRBackwardDifferentiate : IRInst +{ + enum + { + kOp = kIROp_BackwardDifferentiate + }; + // The base function for the call. 
+ IRUse base; + IRInst* getBaseFn() { return getOperand(0); } + + IR_LEAF_ISA(BackwardDifferentiate) +}; + // Dictionary item mapping a type with a corresponding // IDifferentiable witness table // @@ -2497,6 +2522,7 @@ public: IRInst* existentialValue); IRInst* emitForwardDifferentiateInst(IRType* type, IRInst* baseFn); + IRInst* emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); @@ -2999,6 +3025,7 @@ public: IRInst* emitBitAnd(IRType* type, IRInst* left, IRInst* right); IRInst* emitBitOr(IRType* type, IRInst* left, IRInst* right); IRInst* emitBitNot(IRType* type, IRInst* value); + IRInst* emitNeg(IRType* type, IRInst* value); IRInst* emitAdd(IRType* type, IRInst* left, IRInst* right); IRInst* emitSub(IRType* type, IRInst* left, IRInst* right); @@ -3207,11 +3234,21 @@ public: addDecoration(value, kIROp_ForwardDifferentiableDecoration); } + void addBackwardDifferentiableDecoration(IRInst* value) + { + addDecoration(value, kIROp_BackwardDifferentiableDecoration); + } + void addForwardDerivativeDecoration(IRInst* value, IRInst* fwdFunc) { addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc); } + void addBackwardDerivativeDecoration(IRInst* value, IRInst* jvpFn) + { + addDecoration(value, kIROp_BackwardDerivativeDecoration, jvpFn); + } + void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) { addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 9d538a774..6112beaf2 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3109,6 +3109,17 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn) + { + auto inst = createInst<IRBackwardDifferentiate>( + this, + kIROp_BackwardDifferentiate, + type, + baseFn); + addInst(inst); + return inst; + } + IRInst* 
IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) { IRInst* args[] = {primal, differential}; @@ -4556,6 +4567,17 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitNeg(IRType* type, IRInst* value) + { + auto inst = createInst<IRInst>( + this, + kIROp_Neg, + type, + value); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitAdd(IRType* type, IRInst* left, IRInst* right) { auto inst = createInst<IRInst>( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 5930875f1..a0158cf38 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3081,6 +3081,21 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> baseVal.val)); } + // Emit IR to denote the backward-mode derivative + // of the inner func-expr. This will be resolved + // to a concrete function during the derivative + // pass. + LoweredValInfo visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + auto baseVal = lowerSubExpr(expr->baseFunction); + SLANG_ASSERT(baseVal.flavor == LoweredValInfo::Flavor::Simple); + + return LoweredValInfo::simple( + getBuilder()->emitBackwardDifferentiateInst( + lowerType(context, expr->type), + baseVal.val)); + } + LoweredValInfo visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { auto baseVal = lowerSubExpr(expr->arrayExpr); @@ -7799,6 +7814,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { getBuilder()->addForwardDifferentiableDecoration(irFunc); } + if (decl->findModifier<BackwardDifferentiableAttribute>()) + { + getBuilder()->addBackwardDifferentiableDecoration(irFunc); + } if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) { lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index 78edd4deb..d3dc5964e 100644 --- a/source/slang/slang-parser.cpp +++ 
b/source/slang/slang-parser.cpp @@ -2109,6 +2109,26 @@ namespace Slang return parseForwardDifferentiate(parser); } + /// Parse an expression of the form __bwd_diff(fn) where fn is an + /// identifier pointing to a function. + static Expr* parseBackwardDifferentiate(Parser* parser) + { + BackwardDifferentiateExpr* bwdDiffExpr = parser->astBuilder->create<BackwardDifferentiateExpr>(); + + parser->ReadToken(TokenType::LParent); + + bwdDiffExpr->baseFunction = parser->ParseExpression(); + + parser->ReadToken(TokenType::RParent); + + return bwdDiffExpr; + } + + static NodeBase* parseBackwardDifferentiate(Parser* parser, void* /* unused */) + { + return parseBackwardDifferentiate(parser); + } + /// Parse a `This` type expression static Expr* parseThisTypeExpr(Parser* parser) { @@ -6646,7 +6666,8 @@ namespace Slang _makeParseExpr("none", parseNoneExpr), _makeParseExpr("try", parseTryExpr), _makeParseExpr("__TaggedUnion", parseTaggedUnionType), - _makeParseExpr("__fwd_diff", parseForwardDifferentiate) + _makeParseExpr("__fwd_diff", parseForwardDifferentiate), + _makeParseExpr("__bwd_diff", parseBackwardDifferentiate) }; ConstArrayView<SyntaxParseInfo> getSyntaxParseInfos() diff --git a/source/slang/slang.natvis b/source/slang/slang.natvis index 38244f9eb..13334e00c 100644 --- a/source/slang/slang.natvis +++ b/source/slang/slang.natvis @@ -235,6 +235,7 @@ <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ExtractExistentialValueExpr">(Slang::ExtractExistentialValueExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::OpenRefExpr">(Slang::OpenRefExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ForwardDifferentiateExpr">(Slang::ForwardDifferentiateExpr*)&astNodeType</ExpandedItem> + <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::BackwardDifferentiateExpr">(Slang::BackwardDifferentiateExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == 
Slang::ASTNodeType::TaggedUnionTypeExpr">(Slang::TaggedUnionTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::ThisTypeExpr">(Slang::ThisTypeExpr*)&astNodeType</ExpandedItem> <ExpandedItem Condition="astNodeType == Slang::ASTNodeType::AndTypeExpr">(Slang::AndTypeExpr*)&astNodeType</ExpandedItem> diff --git a/tests/autodiff/backward-diff-smoke.slang b/tests/autodiff/backward-diff-smoke.slang new file mode 100644 index 000000000..3d0d8970d --- /dev/null +++ b/tests/autodiff/backward-diff-smoke.slang @@ -0,0 +1,25 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float test(float x, float y) +{ + return 2.0f * x + y * 4.0f - x; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + dpfloat dpa = dpfloat(2.0, 0.0); + dpfloat dpb = dpfloat(1.5, 0.0); + + __bwd_diff(test)(dpa, dpb, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 1 + outputBuffer[1] = dpb.d; // Expect: 4 +} diff --git a/tests/autodiff/backward-diff-smoke.slang.expected.txt b/tests/autodiff/backward-diff-smoke.slang.expected.txt new file mode 100644 index 000000000..8b514833c --- /dev/null +++ b/tests/autodiff/backward-diff-smoke.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +1.0 +4.0 +0.0 +0.0 +0.0
\ No newline at end of file |
