#include "slang-ir-autodiff-rev.h"

#include "slang-ir-autodiff-cfg-norm.h"
#include "slang-ir-autodiff-fwd.h"
#include "slang-ir-clone.h"
#include "slang-ir-dce.h"
#include "slang-ir-dominators.h"
#include "slang-ir-eliminate-multilevel-break.h"
#include "slang-ir-init-local-var.h"
#include "slang-ir-inline.h"
#include "slang-ir-inst-pass-base.h"
#include "slang-ir-loop-unroll.h"
#include "slang-ir-redundancy-removal.h"
#include "slang-ir-single-return.h"
#include "slang-ir-ssa-simplification.h"
#include "slang-ir-util.h"

// Backward (reverse-mode) differentiation transcribers.
//
// NOTE(review): in this copy of the file, template arguments appear to be missing from many
// expressions (e.g. `as(...)`, `cast(...)`, `findDecoration()`, `List`). These tokens are
// reproduced exactly as found. Restore them from version control before relying on this text.

namespace Slang
{

// Builds the function type used by backward-propagate style functions.
//
// Each parameter type of `funcType` is mapped through transcribeParamTypeForPropagateFunc, and
// the mapped type is appended only when that call returns non-null. If the result type has a
// differential type (differentiateType returns non-null), that type is appended as an extra
// parameter (the "dOut" derivative of the result). If `intermeidateType` is non-null, it is
// appended last. The resulting function type always returns void.
//
// NOTE(review): `intermeidateType` is a misspelling of `intermediateType`.
IRFuncType* BackwardDiffTranscriberBase::differentiateFunctionTypeImpl(
    IRBuilder* builder,
    IRFuncType* funcType,
    IRInst* intermeidateType)
{
    List newParameterTypes;
    IRType* diffReturnType;

    for (UIndex i = 0; i < funcType->getParamCount(); i++)
    {
        auto origType = funcType->getParamType(i);
        auto paramType = transcribeParamTypeForPropagateFunc(builder, origType);
        // Parameters with no propagate-side counterpart are dropped.
        if (paramType)
            newParameterTypes.add(paramType);
    }

    // Append the differential of the result type, if the result is differentiable.
    if (auto diffResultType = differentiateType(builder, funcType->getResultType()))
        newParameterTypes.add(diffResultType);

    // The intermediate-context type (if any) is always the last parameter.
    if (intermeidateType)
    {
        newParameterTypes.add((IRType*)intermeidateType);
    }

    diffReturnType = builder->getVoidType();

    return builder->getFuncType(newParameterTypes, diffReturnType);
}

// Computes the type of the backward-derivative "primal" function of `func`.
//
// Each original parameter type is mapped through transcribeParamTypeForPrimalFunc. One extra
// trailing `out` parameter of the backward-diff intermediate-context type is then appended. When
// the builder's insert location is inside a generic, that context type is specialized with the
// enclosing generic. The result type is the primal form of the original result type.
IRFuncType* BackwardDiffPrimalTranscriber::differentiateFunctionType(
    IRBuilder* builder,
    IRInst* func,
    IRFuncType* funcType)
{
    IRType* intermediateType =
        builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
    if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent()))
    {
        intermediateType =
            (IRType*)specializeWithGeneric(*builder, intermediateType, as(outerGeneric));
    }
    auto outType = builder->getOutParamType(intermediateType);
    List paramTypes;
    for (UInt i = 0; i < funcType->getParamCount(); i++)
    {
        auto origType = funcType->getParamType(i);
        auto primalType = transcribeParamTypeForPrimalFunc(builder, origType);
        paramTypes.add(primalType);
    }
    // The intermediate context is the last parameter, passed as `out`.
    paramTypes.add(outType);
    IRFuncType* primalFuncType = builder->getFuncType(
        paramTypes,
        (IRType*)findOrTranscribePrimalInst(builder, funcType->getResultType()));
    return primalFuncType;
}

// Records `diffFunc` as the transcribed primal function of `primalFunc` and marks `diffFunc` with
// an IgnoreSideEffects decoration. No body is emitted here.
InstPair BackwardDiffPrimalTranscriber::transcribeFunc(
    IRBuilder* builder,
    IRFunc* primalFunc,
    IRFunc* diffFunc)
{
    // Don't need to do anything other than add a decoration in the original func to point to the
    // primal func. The body of the primal func will be generated by propagateTranscriber together
    // with propagate func.
    addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);
    builder->addDecoration(diffFunc, kIROp_IgnoreSideEffectsDecoration);
    return InstPair(primalFunc, diffFunc);
}

// Emits one param at the builder's current insert location for every parameter in `func`'s
// function type. Returns the new params in declaration order.
static List _defineFuncParams(IRBuilder* builder, IRFunc* func)
{
    auto propFuncType = cast(func->getFullType());
    List params;
    for (UInt i = 0; i < propFuncType->getParamCount(); i++)
    {
        auto paramType = propFuncType->getParamType(i);
        auto param = builder->emitParam(paramType);
        params.add(param);
    }
    return params;
}

// Generates trivial backward-diff functions for `originalFunc` when the user supplied a backward
// derivative (`udfDecor`):
//  - An empty struct is created, hoisted out of any enclosing generic, and recorded on
//    `originalFunc` as its backward-derivative intermediate-context type.
//  - `diffPropFunc` gets a body that forwards all of its parameters except the last one to the
//    user-defined derivative, then returns void.
//  - The backward-derivative primal function gets a body that calls `originalFunc` with its own
//    parameters (minus the trailing intermediate-context parameter) and returns the result. If
//    that primal function does not exist yet, it is created on demand through the primal
//    transcriber.
void BackwardDiffPropagateTranscriber::generateTrivialDiffFuncFromUserDefinedDerivative(
    IRBuilder* builder,
    IRFunc* originalFunc,
    IRFunc* diffPropFunc,
    IRUserDefinedBackwardDerivativeDecoration* udfDecor)
{
    // Create an empty struct type to use as the intermediate context type.
    auto originalGeneric = findOuterGeneric(originalFunc);
    builder->setInsertBefore(originalFunc);
    IRInst* emptyStruct = builder->createStructType();
    IRInst* emptyStructType = nullptr;
    auto emptyStructGeneric = hoistValueFromGeneric(*builder, emptyStruct, emptyStructType, false);
    builder->addBackwardDerivativeIntermediateTypeDecoration(originalFunc, emptyStructGeneric);

    IRInst* udf = udfDecor->getBackwardDerivativeFunc();

    // Body of the propagate func: forward its parameters to the user-defined derivative.
    builder->setInsertInto(diffPropFunc);
    builder->emitBlock();
    List params = _defineFuncParams(builder, diffPropFunc);
    // The last parameter is not forwarded. It is presumably the intermediate-context parameter
    // that differentiateFunctionTypeImpl appends last, which the user-defined derivative does not
    // take.
    params.removeLast();
    IRInst* udfRefFromPropFunc = udf;
    if (auto specialize = as(udf))
    {
        // Re-specialize the user-defined derivative with the propagate func's own outer generic.
        // From here on, `udf` refers to the unspecialized base.
        udf = specialize->getBase();
        auto propGeneric = findOuterGeneric(diffPropFunc);
        SLANG_RELEASE_ASSERT(propGeneric);
        udfRefFromPropFunc = maybeSpecializeWithGeneric(*builder, udf, propGeneric);
    }
    builder->emitCallInst(builder->getVoidType(), udfRefFromPropFunc, params);
    builder->emitReturn();

    // Copy other decorations from the user-defined derivative to the generated propagate func.
    copyOriginalDecorations(udf, diffPropFunc);

    // Now create the trivial primal function.
    auto existingDecor = originalFunc->findDecoration();
    if (!existingDecor)
    {
        // We haven't created a header for primal func yet, create it now.
        if (originalGeneric)
            builder->setInsertBefore(originalGeneric);
        else
            builder->setInsertBefore(originalFunc);
        autoDiffSharedContext->transcriberSet.primalTranscriber->transcribe(
            builder,
            originalGeneric ? originalGeneric : originalFunc);
        existingDecor = originalFunc->findDecoration();
    }
    SLANG_RELEASE_ASSERT(existingDecor);

    // Fill the primal func header with trivial call to original func.
    IRInst* existingPrimalFunc = existingDecor->getBackwardDerivativePrimalFunc();
    // NOTE(review): `existingPriamlFuncGeneric` is a misspelling of `existingPrimalFuncGeneric`.
    IRGeneric* existingPriamlFuncGeneric = nullptr;
    if (auto specialize = as(existingPrimalFunc))
    {
        // The decoration refers to a specialization; work on the function inside the generic.
        existingPriamlFuncGeneric = as(specialize->getBase());
        existingPrimalFunc = findGenericReturnVal(existingPriamlFuncGeneric);
    }
    // NOTE(review): this setInsertBefore is immediately superseded by setInsertInto and appears
    // redundant.
    builder->setInsertBefore(existingPrimalFunc);
    builder->setInsertInto(existingPrimalFunc);

    // Prefer the checkpoint hint on the user-defined derivative. Fall back to the one on the
    // original function.
    auto checkpointHint = udf->findDecoration();
    if (!checkpointHint)
        checkpointHint = originalFunc->findDecoration();
    if (checkpointHint)
        cloneCheckpointHint(
            builder,
            checkpointHint,
            cast(existingPrimalFunc));

    // Copy other decorations from the user-defined derivative to the generated primal func wrapper.
    copyOriginalDecorations(udf, existingPrimalFunc);

    builder->emitBlock();
    params = _defineFuncParams(builder, as(existingPrimalFunc));
    // Drop the trailing parameter. It is expected to be the `out` intermediate-context parameter
    // appended by BackwardDiffPrimalTranscriber::differentiateFunctionType, and this body leaves
    // it unwritten.
    params.removeLast();

    // Unwrap any differential pointer pairs into their primal pointers. We need this special case
    // for trivial funcs.
    for (Int i = 0; i < params.getCount(); i++)
    {
        if (as(params[i]->getDataType()))
        {
            params[i] = builder->emitDifferentialPtrPairGetPrimal(params[i]);
        }
    }
    IRInst* originalFuncRefFromPrimalFunc = originalFunc;
    if (originalGeneric)
        originalFuncRefFromPrimalFunc =
            maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric);
    auto result = builder->emitCallInst(
        cast(existingPrimalFunc->getFullType())->getResultType(),
        originalFuncRefFromPrimalFunc,
        params);
    builder->emitReturn(result);
}

// Computes the type of the backward-propagate function of `func`. The propagate type is built by
// differentiateFunctionTypeImpl, with an intermediate-context parameter chosen as follows:
//  - If the insert location is inside a generic, the context type of `func` is specialized with
//    that generic.
//  - If `func` passes the `as(func)` test below, no context parameter is added. NOTE(review): the
//    target type of that test is not visible in this copy.
//  - Otherwise, the unspecialized context type of `func` is used.
IRFuncType* BackwardDiffPropagateTranscriber::differentiateFunctionType(
    IRBuilder* builder,
    IRInst* func,
    IRFuncType* funcType)
{
    IRType* intermediateType = nullptr;
    if (auto outerGeneric = findOuterGeneric(builder->getInsertLoc().getParent()))
    {
        intermediateType =
            builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
        intermediateType =
            (IRType*)specializeWithGeneric(*builder, intermediateType, as(outerGeneric));
    }
    else if (as(func))
    {
        intermediateType = nullptr;
    }
    else
    {
        intermediateType =
            builder->getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(func));
    }
    return differentiateFunctionTypeImpl(builder, funcType, intermediateType);
}

// The combined backward-derivative function takes the propagate-style parameters but no
// intermediate-context parameter.
IRFuncType* BackwardDiffTranscriber::differentiateFunctionType(
    IRBuilder* builder,
    IRInst* func,
    IRFuncType* funcType)
{
    SLANG_UNUSED(func);
    return differentiateFunctionTypeImpl(builder, funcType, nullptr);
}

// Records `diffFunc` as the propagate function of `primalFunc`, then fills in its body:
//  - If `primalFunc` carries a user-defined backward derivative decoration, trivial
//    forwarding wrappers are generated.
//  - Otherwise, the full reverse-mode transcription (transcribeFuncImpl) runs.
InstPair BackwardDiffPropagateTranscriber::transcribeFunc(
    IRBuilder* builder,
    IRFunc* primalFunc,
    IRFunc* diffFunc)
{
    addTranscribedFuncDecoration(*builder, primalFunc, diffFunc);

    if (auto udf = primalFunc->findDecoration())
    {
        generateTrivialDiffFuncFromUserDefinedDerivative(builder, primalFunc, diffFunc, udf);
    }
    else
    {
        transcribeFuncImpl(builder, primalFunc, diffFunc);
    }
    return InstPair(primalFunc, diffFunc);
}

// Dispatches transcription of a single instruction by opcode:
//  - Params, returns, witness lookups and specializations have dedicated handlers.
//  - The listed literal, existential and debug ops are treated as non-differentiable.
//  - Struct keys map to themselves with no differential.
//  - Any other op yields an empty (nullptr, nullptr) pair.
InstPair BackwardDiffTranscriberBase::transcribeInstImpl(IRBuilder* builder, IRInst* origInst)
{
    switch (origInst->getOp())
    {
    case kIROp_Param:
        return transcribeParam(builder, as(origInst));

    case kIROp_Return:
        return transcribeReturn(builder, as(origInst));

    case kIROp_LookupWitnessMethod:
        return transcribeLookupInterfaceMethod(builder, as(origInst));

    case kIROp_Specialize:
        return transcribeSpecialize(builder, as(origInst));

    case kIROp_MakeTuple:
    case kIROp_FloatLit:
    case kIROp_IntLit:
    case kIROp_VoidLit:
    case kIROp_ExtractExistentialWitnessTable:
    case kIROp_ExtractExistentialType:
    case kIROp_ExtractExistentialValue:
    case kIROp_WrapExistential:
    case kIROp_MakeExistential:
    case kIROp_MakeExistentialWithRTTI:
    case kIROp_DebugInlinedAt:
    case kIROp_DebugScope:
    case kIROp_DebugNoScope:
    case kIROp_DebugInlinedVariable:
    case kIROp_DebugFunction:
        return transcribeNonDiffInst(builder, origInst);

    case kIROp_StructKey:
        return InstPair(origInst, nullptr);
    }

    return InstPair(nullptr, nullptr);
}

// Returns "dp" followed by the variable's name hint, for use as a name hint for parameters.
// If no primal name is available, returns a blank string.
//
String BackwardDiffTranscriberBase::makeDiffPairName(IRInst* origVar)
{
    // Presumably tests for a name-hint decoration. The decoration type is not visible in this
    // copy.
    if (auto namehintDecoration = origVar->findDecoration())
    {
        return ("dp" + String(namehintDecoration->getName()));
    }
    return String("");
}

// NOTE(review): in this part of the file, template arguments appear to be missing from many
// expressions (e.g. `as(...)`, `cast(...)`, `findDecoration()`, `List`, `static_cast(...)`).
// These tokens are reproduced exactly as found.

// Handles the no_diff case for a parameter type.
//
// Looks at `origType`, or at its value type if `origType` passes the first `as(...)` test
// (pointer-like, given the getValueType and getPtrType usage). If that type passes the second
// `as(...)` test and carries the attribute checked below, the function returns the primal form of
// that value type. When `origType` was pointer-like, the result is re-wrapped in a pointer type
// with the same opcode. The attribute is presumably no_diff, per the comments at the call sites.
// Otherwise returns nullptr.
static IRType* _getPrimalTypeFromNoDiffType(
    BackwardDiffTranscriberBase* transcriber,
    IRBuilder* builder,
    IRType* origType)
{
    IRType* valueType = origType;
    auto ptrType = as(valueType);
    if (ptrType)
        valueType = ptrType->getValueType();
    if (auto attrType = as(valueType))
    {
        if (attrType->findAttr())
        {
            auto primalValueType =
                (IRType*)transcriber->findOrTranscribePrimalInst(builder, valueType);
            if (ptrType)
                return builder->getPtrType(ptrType->getOp(), primalValueType);
            return primalValueType;
        }
    }
    return nullptr;
}

// Maps an original parameter type to the type used by the backward-derivative primal function:
//  - no_diff parameters keep their primal type.
//  - Differentiable pointer types become differential pair types.
//  - Everything else becomes its primal type.
IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPrimalFunc(
    IRBuilder* builder,
    IRType* paramType)
{
    // If the param is marked as no_diff, return the primal type.
    if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
        return primalNoDiffType;

    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);

    // Differentiable pointer types are treated as primal pairs, since they aren't involved in the
    // transposition process.
    //
    if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType))
    {
        auto diffPairType = tryGetDiffPairType(builder, primalType);
        SLANG_ASSERT(diffPairType);
        return diffPairType;
    }

    return primalType;
}

// Maps an original parameter type to the type used by the backward-propagate function.
// May return nullptr when the first branch's differentiateType yields nothing. Callers treat a
// null result as "no propagate-side parameter".
IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(
    IRBuilder* builder,
    IRType* paramType)
{
    // Presumably `out` parameters (the target type of this test is not visible in this copy).
    // Only the differential of the value type is passed to the propagate function.
    if (auto outType = as(paramType))
    {
        auto valueType = outType->getValueType();
        auto diffValueType = differentiateType(builder, valueType);
        return diffValueType;
    }

    // Unwraps a type that passes the inner `as(...)` test (presumably an inout type) to its value
    // type. Other types are returned unchanged.
    auto maybeConvertInOutTypeToValueType = [](IRType* type)
    {
        if (auto inoutType = as(type))
            return inoutType->getValueType();
        return type;
    };

    // If the param is marked as no_diff, return the primal type.
    if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType))
        return maybeConvertInOutTypeToValueType(primalNoDiffType);

    // Differentiable parameters are passed as differential pairs. A pair type that is neither a
    // relevant pointer type nor passes the second `as(...)` test is wrapped as a borrow-inout
    // parameter.
    auto diffPairType = tryGetDiffPairType(builder, paramType);
    if (diffPairType)
    {
        if (!asRelevantPtrType(diffPairType) && !as(diffPairType))
            return builder->getBorrowInOutParamType(diffPairType);
        return diffPairType;
    }

    // Non-differentiable parameters are passed as their primal type, with any inout wrapper removed.
    auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType);
    return maybeConvertInOutTypeToValueType(primalType);
}

// Create an empty func to represent the transcribed func of `origFunc`.
//
// Returns (nullptr, nullptr) when `origFunc` is not backward-differentiable and also lacks the
// decoration tested below. Otherwise it:
//  - creates `diffFunc` with the type computed by this transcriber's differentiateFunctionType,
//  - records `diffFunc` on `origFunc`,
//  - copies checkpoint hints and other decorations,
//  - marks `diffFunc` backward-differentiable with fast floating-point mode,
//  - returns (origFunc, diffFunc).
// No body is emitted here.
InstPair BackwardDiffTranscriberBase::transcribeFuncHeaderImpl(
    IRBuilder* inBuilder,
    IRFunc* origFunc)
{
    if (!isBackwardDifferentiableFunc(origFunc) && !origFunc->findDecoration())
        return InstPair(nullptr, nullptr);

    IRBuilder builder = *inBuilder;

    IRFunc* primalFunc = origFunc;

    maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc);

    differentiableTypeConformanceContext.setFunc(origFunc);

    auto diffFunc = builder.createFunc();

    SLANG_ASSERT(as(origFunc->getFullType()));
    // Any insts emitted while computing the derivative's type are placed before `diffFunc`.
    builder.setInsertBefore(diffFunc);
    IRType* diffFuncType = this->differentiateFunctionType(
        &builder,
        origFunc,
        as(origFunc->getFullType()));
    diffFunc->setFullType(diffFuncType);

    // Presumably a name-hint check: give the derivative a name derived from the original.
    if (origFunc->findDecoration())
    {
        auto newName = this->getTranscribedFuncName(&builder, origFunc);
        builder.addNameHintDecoration(diffFunc, newName);
    }

    addTranscribedFuncDecoration(builder, primalFunc, diffFunc);

    // Transfer checkpoint hint decorations
    copyCheckpointHints(&builder, origFunc, diffFunc);

    // Mark the generated derivative function itself as differentiable.
    builder.addBackwardDifferentiableDecoration(diffFunc);

    copyOriginalDecorations(origFunc, diffFunc);

    builder.addFloatingModeOverrideDecoration(diffFunc, FloatingPointMode::Fast);

    return InstPair(primalFunc, diffFunc);
}

// Records `transcribedFunc` as the transcribed (derivative) function of `origFunc` via
// addExistingDiffFuncDecor. If `transcribedFunc` is nested in a generic, the decoration instead
// refers to that generic specialized with `origFunc`'s outer generic. Any insts for that
// specialization are emitted before `origFunc`.
void BackwardDiffTranscriberBase::addTranscribedFuncDecoration(
    IRBuilder& builder,
    IRFunc* origFunc,
    IRFunc* transcribedFunc)
{
    IRBuilder subBuilder = builder;
    if (auto outerGen = findOuterGeneric(transcribedFunc))
    {
        subBuilder.setInsertBefore(origFunc);
        auto specialized =
            specializeWithGeneric(subBuilder, outerGen, as(findOuterGeneric(origFunc)));
        addExistingDiffFuncDecor(&subBuilder, origFunc, specialized);
    }
    else
    {
        addExistingDiffFuncDecor(&subBuilder, origFunc, transcribedFunc);
    }
}

// Creates the derivative func header for `origFunc` (see transcribeFuncHeaderImpl). If a header
// was created, also queues a body-transcription task of type `diffTaskType` in
// autoDiffSharedContext->followUpFunctionsToTranscribe.
InstPair BackwardDiffTranscriberBase::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
    InstPair result;

    // If we're transcribing a function as a 'value' (i.e. maybe embedded in a generic), keep the
    // insert location unchanged. If we're transcribing it as a declaration, we should
    // insert into the module.
    //
    auto origOuterGen = as(findOuterGeneric(origFunc));
    if (!origOuterGen || !(findInnerMostGenericReturnVal(origOuterGen) == origFunc))
    {
        // Dealing with a declaration: insert into module scope.
        IRBuilder subBuilder = *inBuilder;
        subBuilder.setInsertInto(inBuilder->getModule());
        result = transcribeFuncHeaderImpl(&subBuilder, origFunc);
    }
    else
    {
        result = transcribeFuncHeaderImpl(inBuilder, origFunc);
    }

    FuncBodyTranscriptionTask task;
    task.originalFunc = as(result.primal);
    task.resultFunc = as(result.differential);
    task.type = diffTaskType;

    if (task.resultFunc)
    {
        autoDiffSharedContext->followUpFunctionsToTranscribe.add(task);
    }

    return result;
}

// Creates the combined backward-derivative function for `origFunc`, reusing an existing one if
// found. It then immediately emits the body, which:
//  1. declares the propagate-function parameters;
//  2. calls the BackwardDifferentiatePrimal of `origFunc` with the primal values of those
//     parameters (temp vars for mutable parameters), plus a local intermediate-context var;
//  3. calls the BackwardDifferentiatePropagate of `origFunc` with the propagate parameters, the
//     dOut parameter (if the result is differentiable) and the loaded intermediate context;
//  4. returns void.
// Unlike the base-class version, no follow-up body-transcription task is queued.
InstPair BackwardDiffTranscriber::transcribeFuncHeader(IRBuilder* inBuilder, IRFunc* origFunc)
{
    if (auto bwdDiffFunc = findExistingDiffFunc(origFunc))
        return InstPair(origFunc, bwdDiffFunc);

    auto header = transcribeFuncHeaderImpl(inBuilder, origFunc);
    if (!header.differential)
        return header;

    IRBuilder builder = *inBuilder;
    builder.setInsertInto(header.differential);
    builder.emitBlock();
    auto origFuncType = as(origFunc->getFullType());

    List primalArgs, propagateArgs;
    List primalTypes, propagateTypes;

    IRType* primalResultType =
        transcribeParamTypeForPrimalFunc(&builder, origFuncType->getResultType());

    // NOTE(review): assumes `origFunc` has one param per parameter in its function type.
    // `currentParam` is dereferenced on every iteration without a null check.
    IRParam* currentParam = origFunc->getFirstParam();
    for (UInt i = 0; i < origFuncType->getParamCount(); i++)
    {
        IRBuilderSourceLocRAII sourceLocationScope(&builder, currentParam->sourceLoc);
        auto primalParamType =
            transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i));
        auto propagateParamType =
            transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i));
        if (propagateParamType)
        {
            auto param = builder.emitParam(propagateParamType);
            propagateTypes.add(propagateParamType);
            propagateArgs.add(param);

            // Fetch primal values to use as arguments in primal func call.
            IRInst* primalArg = param;
            if (!as(primalParamType) && !as(primalParamType))
            {
                // As long as the primal parameter is not an out or constref type,
                // we need to fetch the primal value from the parameter.
                if (asRelevantPtrType(propagateParamType))
                {
                    primalArg = builder.emitLoad(param);
                }
                if (const auto diffPairType = as(primalArg->getDataType()))
                {
                    primalArg = builder.emitDifferentialPairGetPrimal(primalArg);
                }
            }

            if (auto primalParamPtrType = isMutablePointerType(primalParamType))
            {
                // If primal parameter is mutable, we need to pass in a temp var.
                auto tempVar = builder.emitVar(primalParamPtrType->getValueType());

                // If the parameter is not a pure 'out' param, we also need to set up the initial
                // value of the temp var, otherwise the temp var will be uninitialized which could
                // cause undefined behavior in the primal function.
                //
                if (!as(primalParamType))
                    builder.emitStore(tempVar, primalArg);

                primalArgs.add(tempVar);
            }
            else
            {
                primalArgs.add(primalArg);
            }
        }
        else
        {
            // No propagate-side parameter. The primal parameter must be pointer-like (asserted),
            // so give the primal call a fresh local var of its value type.
            auto primalPtrType = asRelevantPtrType(primalParamType);
            SLANG_RELEASE_ASSERT(primalPtrType);
            auto primalValueType = primalPtrType->getValueType();
            auto var = builder.emitVar(primalValueType);
            primalArgs.add(var);
        }
        primalTypes.add(primalParamType);
        currentParam = currentParam->getNextParam();
    }

    // Add dOut argument to propagateArgs.
    auto diffResultType = differentiateType(&builder, origFunc->getResultType());
    if (diffResultType)
    {
        auto param = builder.emitParam(diffResultType);
        propagateArgs.add(param);
        propagateTypes.add(param->getFullType());
    }

    // Resolve the intermediate-context type and the reference to `origFunc` in the generic
    // context of the new derivative function.
    auto outerGeneric = findOuterGeneric(origFunc);
    IRType* intermediateType =
        builder.getBackwardDiffIntermediateContextType(maybeFindOuterGeneric(origFunc));
    IRInst* specializedOriginalFunc = origFunc;
    if (outerGeneric)
    {
        specializedOriginalFunc = maybeSpecializeWithGeneric(
            builder,
            outerGeneric,
            findOuterGeneric(header.differential));
        intermediateType = (IRType*)specializeWithGeneric(
            builder,
            intermediateType,
            as(findOuterGeneric(header.differential)));
    }

    // Call the primal func. `intermediateVar` is passed as the last argument, for the primal
    // func's trailing out intermediate-context parameter.
    auto intermediateVar = builder.emitVar(intermediateType);
    // NOTE(review): `primalFuncType` is built before the out intermediate-context type is appended
    // to `primalTypes`, so it has one fewer parameter than `primalArgs` has arguments.
    // `primalTypes` is not read again after the append below. Confirm whether the append was meant
    // to precede getFuncType.
    auto primalFuncType = builder.getFuncType(primalTypes, primalResultType);
    primalArgs.add(intermediateVar);
    primalTypes.add(builder.getOutParamType(intermediateType));
    auto primalFunc =
        builder.emitBackwardDifferentiatePrimalInst(primalFuncType, specializedOriginalFunc);
    builder.emitCallInst(primalResultType, primalFunc, primalArgs);

    // Call the propagate func with the loaded intermediate context as its last argument.
    propagateTypes.add(intermediateType);
    propagateArgs.add(builder.emitLoad(intermediateVar));
    auto propagateFuncType = builder.getFuncType(propagateTypes, builder.getVoidType());
    auto propagateFunc =
        builder.emitBackwardDifferentiatePropagateInst(propagateFuncType, specializedOriginalFunc);
    builder.emitCallInst(builder.getVoidType(), propagateFunc, propagateArgs);

    builder.emitReturn();
    // NOTE(review): transcribeFuncHeaderImpl already called addTranscribedFuncDecoration for this
    // (origFunc, diffFunc) pair. Confirm that addExistingDiffFuncDecor tolerates being called twice.
    addTranscribedFuncDecoration(builder, origFunc, cast(header.differential));
    return header;
}

// Puts parameters into their own block.
//
// Steps:
//  1. Insert a new block before `func`'s first block.
//  2. Move all of `func`'s params into the new block.
//  3. Redirect uses of the old first block to the new one.
//  4. End the new block with an unconditional branch to the old first block.
void BackwardDiffTranscriberBase::makeParameterBlock(IRBuilder* inBuilder, IRFunc* func)
{
    IRBuilder builder = *inBuilder;

    auto firstBlock = func->getFirstBlock();
    IRParam* param = func->getFirstParam();

    builder.setInsertBefore(firstBlock);

    // Note: It looks like emitBlock() doesn't use the current
    // builder position, so we're going to manually move the new block
    // to before the existing block.
    auto paramBlock = builder.emitBlock();
    paramBlock->insertBefore(firstBlock);
    builder.setInsertInto(paramBlock);

    while (param)
    {
        IRParam* nextParam = param->getNextParam();

        // Move inst into the new parameter block.
        param->insertAtEnd(paramBlock);

        param = nextParam;
    }

    // Replace this block as the first block. This runs before the branch below is emitted, so that
    // branch still targets `firstBlock`.
    firstBlock->replaceUsesWith(paramBlock);

    // Add terminator inst.
    builder.emitBranch(firstBlock);
}

// Rewrites `func` in place into a form that can be reverse-differentiated. It:
//  - removes linkage decorations,
//  - runs pre-autodiff force-inlining,
//  - converts the function to a single return,
//  - eliminates continue blocks and multi-level breaks,
//  - normalizes the CFG.
//
// Always returns SLANG_OK. A function with no return is reported through the diagnostic sink,
// not through the return value.
SlangResult BackwardDiffTranscriberBase::prepareFuncForBackwardDiff(IRFunc* func)
{
    removeLinkageDecorations(func);

    performPreAutoDiffForceInlining(func);

    // NOTE(review): `diffTypeContext` is not used after setFunc(). Keep it only if its
    // construction or setFunc() has a side effect that is needed here.
    DifferentiableTypeConformanceContext diffTypeContext(autoDiffSharedContext);
    diffTypeContext.setFunc(func);

    auto returnCount = getReturnCount(func);
    if (returnCount > 1)
    {
        convertFuncToSingleReturnForm(func->getModule(), func);
    }
    else if (returnCount == 0)
    {
        // The function is ill-formed and never returns (such as having an infinite loop),
        // we can't possibly reverse-differentiate such functions, so we will diagnose it here.
        getSink()->diagnose(func->sourceLoc, Diagnostics::functionNeverReturnsFatal, func);
    }

    eliminateContinueBlocksInFunc(func->getModule(), func);

    eliminateMultiLevelBreakForFunc(func->getModule(), func);

    IRCFGNormalizationPass cfgPass = {this->getSink()};
    normalizeCFG(autoDiffSharedContext->moduleInst->getModule(), func, cfgPass);

    return SLANG_OK;
}

// Create a copy of originalFunc's forward derivative in the same generic context (if any) of
// `diffPropagateFunc`.
//
// Steps:
//  1. Clone `originalFunc` together with its outer generic, if any.
//  2. Strip derivative decorations from the clone and prepare it with prepareFuncForBackwardDiff.
//  3. Forward-differentiate the clone, running the follow-up body tasks this enqueues
//     immediately.
//  4. Delete the clone and remove redundant loads/stores from the result.
//  5. If the result lives in a generic, migrate it into the outer generic of
//     `diffPropagateFunc`, which must exist.
// Returns the new forward-derivative function.
IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc(
    IRBuilder* builder,
    IRFunc* originalFunc,
    IRFunc* diffPropagateFunc)
{
    auto primalOuterParent = findOuterGeneric(originalFunc);
    if (!primalOuterParent)
        primalOuterParent = originalFunc;

    // Make a clone of original func so we won't modify the original.
    IRCloneEnv originalCloneEnv;
    primalOuterParent = cloneInst(&originalCloneEnv, builder, primalOuterParent);
    auto primalFunc = as(getGenericReturnVal(primalOuterParent));

    // Strip any existing derivative decorations off the clone.
    stripDerivativeDecorations(primalFunc);
    eliminateDeadCode(primalOuterParent);

    // Perform required transformations and simplifications on the original func to make it
    // reversible.
    // NOTE(review): this failure path returns `diffPropagateFunc` rather than nullptr, so a
    // caller's null check will not see the failure, and the clone is left in place. The path is
    // currently unreachable because prepareFuncForBackwardDiff always returns SLANG_OK.
    if (SLANG_FAILED(prepareFuncForBackwardDiff(primalFunc)))
        return diffPropagateFunc;

    // Forward transcribe the clone of the original func.
    ForwardDiffTranscriber& fwdTranscriber = *static_cast(
        autoDiffSharedContext->transcriberSet.forwardTranscriber);
    auto oldCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
    IRFunc* fwdDiffFunc =
        as(getGenericReturnVal(fwdTranscriber.transcribe(builder, primalOuterParent)));
    // NOTE(review): `fwdDiffFunc` is dereferenced on the next line before the assertion below
    // checks it for null. The assertion should come first.
    fwdDiffFunc->sourceLoc = primalFunc->sourceLoc;
    SLANG_ASSERT(fwdDiffFunc);

    // Run the forward body-transcription tasks that transcribe() just enqueued, right away.
    // NOTE(review): the iteration count is fixed at newCount - oldCount. If transcribeFunc
    // enqueues further tasks, getLast() picks those up first, and some earlier tasks would be
    // left in the list.
    auto newCount = autoDiffSharedContext->followUpFunctionsToTranscribe.getCount();
    for (auto i = oldCount; i < newCount; i++)
    {
        auto pendingTask = autoDiffSharedContext->followUpFunctionsToTranscribe.getLast();
        autoDiffSharedContext->followUpFunctionsToTranscribe.removeLast();
        SLANG_RELEASE_ASSERT(pendingTask.type == FuncBodyTranscriptionTaskType::Forward);
        fwdTranscriber.transcribeFunc(builder, pendingTask.originalFunc, pendingTask.resultFunc);
    }

    // Remove the clone of original func.
    primalOuterParent->removeAndDeallocate();

    // Remove redundant loads since they interfere with transposition logic.
    eliminateRedundantLoadStore(fwdDiffFunc);

    // Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`.
    if (auto fwdParentGeneric = as(findOuterGeneric(fwdDiffFunc)))
    {
        // Clone forward derivative func from its own generic into current generic parent.
        GenericChildrenMigrationContext migrationContext;
        auto diffOuterGeneric = as(findOuterGeneric(diffPropagateFunc));
        SLANG_RELEASE_ASSERT(diffOuterGeneric);

        migrationContext.init(fwdParentGeneric, diffOuterGeneric, diffPropagateFunc);

        // Clone the generic's children in order, up to and including `fwdDiffFunc`, in front of
        // `diffPropagateFunc`.
        auto inst = fwdParentGeneric->getFirstBlock()->getFirstOrdinaryInst();
        builder->setInsertBefore(diffPropagateFunc);
        while (inst)
        {
            auto next = inst->getNextInst();
            auto cloned = migrationContext.cloneInst(builder, inst);
            if (inst == fwdDiffFunc)
            {
                fwdDiffFunc = as(cloned);
                break;
            }
            inst = next;
        }
        fwdParentGeneric->removeAndDeallocate();
    }

    return fwdDiffFunc;
}

// Transcribes a param of a generic. The release assert requires that the param's parent's parent
// is an op of kind kIROp_Generic. The primal counterpart comes from maybeCloneForPrimalInst. If
// that counterpart is a param (as tested below), it is moved into the block at the builder's
// current insert location. There is no differential counterpart.
InstPair BackwardDiffTranscriberBase::transcribeFuncParam(
    IRBuilder* builder,
    IRParam* origParam,
    IRInst* primalType)
{
    SLANG_UNUSED(primalType);

    SLANG_RELEASE_ASSERT(
        origParam->getParent() && origParam->getParent()->getParent() &&
        origParam->getParent()->getParent()->getOp() == kIROp_Generic);

    auto primalInst = maybeCloneForPrimalInst(builder, origParam);
    if (auto primalParam = as(primalInst))
    {
        SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock());
        primalParam->removeFromParent();
        builder->getInsertLoc().getBlock()->addParam(primalParam);
    }
    return InstPair(primalInst, nullptr);
}

// Keep primal param replacement insts alive during DCE by adding a KeepAlive decoration to every
// value in `paramInfo.mapPrimalSpecificParamToReplacementInPropFunc`.
static void _lockPrimalParamReplacementInsts(
    IRBuilder* builder,
    ParameterBlockTransposeInfo& paramInfo)
{
    for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
        builder->addKeepAliveDecoration(kv.value);
}

// Remove [KeepAlive] decorations for primal param replacement insts. This is the counterpart of
// _lockPrimalParamReplacementInsts.
// NOTE(review): the result of findDecoration() is dereferenced without a null check, so every
// replacement inst must still carry the decoration added by the lock step.
static void _unlockPrimalParamReplacementInsts(ParameterBlockTransposeInfo& paramInfo)
{
    for (const auto& [_, value] : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc)
        value->findDecoration()->removeAndDeallocate();
}

// Transcribe a function definition.
//
// Fills in `diffPropagateFunc` (the propagate function). It also produces the
// backward-derivative primal function and the intermediate-context type for `primalFunc`.
// Outline:
//  1. Build a temporary forward derivative of `primalFunc` and split off its parameter block.
//  2. Unzip the forward derivative into primal and differential blocks, and move those blocks
//     into `diffPropagateFunc`.
//  3. Transpose the parameter block and the differential blocks.
//  4. Apply the checkpoint policy, extract the primal computation into its own function, write
//     back inout derivatives, and remove params that belong only to the primal function.
//  5. Hoist the intermediate type and the primal function out of any generic and record them on
//     `primalFunc`. An existing primal-func header is reused if there is one.
void BackwardDiffTranscriberBase::transcribeFuncImpl(
    IRBuilder* builder,
    IRFunc* primalFunc,
    IRFunc* diffPropagateFunc)
{
    SLANG_ASSERT(primalFunc);
    SLANG_ASSERT(diffPropagateFunc);
    // Reverse-mode transcription proceeds in the steps outlined above this function.

    // Generate a temporary forward derivative function as an intermediate step.
    IRBuilder tempBuilder = *builder;
    if (auto outerGeneric = findOuterGeneric(diffPropagateFunc))
    {
        tempBuilder.setInsertBefore(outerGeneric);
    }
    else
    {
        tempBuilder.setInsertBefore(diffPropagateFunc);
    }

    auto fwdDiffFunc =
        generateNewForwardDerivativeForFunc(&tempBuilder, primalFunc, diffPropagateFunc);
    // NOTE(review): generateNewForwardDerivativeForFunc returns `diffPropagateFunc`, not nullptr,
    // on its prepare-failure path, so this check does not catch that case.
    if (!fwdDiffFunc)
        return;

    // True when the forward derivative's result type passes this `as(...)` test (presumably a
    // differential pair type; the target type is not visible in this copy).
    bool isResultDifferentiable = as(fwdDiffFunc->getResultType());

    // Split first block into a parameter block.
    this->makeParameterBlock(&tempBuilder, as(fwdDiffFunc));

    // This step adds a decoration to instructions that are computing the differential.
    // TODO: This is disabled for now because fwd-mode already adds differential decorations
    // wherever needed. We need to run this pass only for user-written forward derivative code.
    //
    // diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc);

    diffUnzipPass->unzipDiffInsts(fwdDiffFunc);
    IRFunc* unzippedFwdDiffFunc = fwdDiffFunc;

    // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell.
    builder->setInsertInto(diffPropagateFunc->getParent());
    {
        // Collect the blocks first. Moving a block out of the function breaks the next-block walk.
        List workList;
        for (auto block = unzippedFwdDiffFunc->getFirstBlock(); block;
             block = block->getNextBlock())
            workList.add(block);

        for (auto block : workList)
            block->insertAtEnd(diffPropagateFunc);
    }

    // Transpose the first block (parameter block)
    auto paramTransposeInfo = splitAndTransposeParameterBlock(
        builder,
        diffPropagateFunc,
        primalFunc->sourceLoc,
        isResultDifferentiable);

    // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc
    // may be used by write back logic that we are going to insert later.
    // Before then we want to keep them alive.
    _lockPrimalParamReplacementInsts(builder, paramTransposeInfo);

    builder->setInsertInto(diffPropagateFunc);

    // Transpose the differential blocks, which now live in diffPropagateFunc. Transposition is
    // seeded with dOutParam, the derivative of the return value.
    DiffTransposePass::FuncTranspositionInfo transposeInfo = {paramTransposeInfo.dOutParam};
    diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, transposeInfo);

    // Apply checkpointing policy to legalize cross-scope uses of primal values
    // using either recompute or store strategies.
    auto primalsInfo = applyCheckpointPolicy(diffPropagateFunc);

    eliminateDeadCode(diffPropagateFunc);

    // Extract the primal computations into their own func. Turn all accesses to stored primal
    // insts into explicit reads and writes of the intermediate data structure. `intermediateType`
    // starts as nullptr and is presumably set by extractPrimalFunc.
    IRInst* intermediateType = nullptr;
    auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc(
        diffPropagateFunc,
        primalFunc,
        primalsInfo,
        paramTransposeInfo,
        intermediateType);

    // At this point the unzipped func is just an empty shell
    // and we can simply remove it.
    unzippedFwdDiffFunc->removeAndDeallocate();

    // Write back derivatives to inout parameters.
    writeBackDerivativeToInOutParams(paramTransposeInfo, diffPropagateFunc);

    // Remove params that belong only to the primal func (not in propagateFuncParams).
    List paramsToRemove;
    for (auto param : diffPropagateFunc->getParams())
    {
        if (!paramTransposeInfo.propagateFuncParams.contains(param))
            paramsToRemove.add(param);
    }
    for (auto param : paramsToRemove)
    {
        // Any remaining use must have a replacement recorded by the parameter transposition.
        if (param->hasUses())
        {
            IRInst* replacement = nullptr;
            paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc.tryGetValue(
                param,
                replacement);
            SLANG_RELEASE_ASSERT(replacement);
            param->replaceUsesWith(replacement);
        }
        param->removeAndDeallocate();
    }

    _unlockPrimalParamReplacementInsts(paramTransposeInfo);

    // If primal function is nested in a generic, we want to create separate generics for all the
    // associated things we have just created.
    auto primalOuterGeneric = findOuterGeneric(primalFunc);
    IRInst* specializedFunc = nullptr;
    auto intermediateTypeGeneric =
        hoistValueFromGeneric(*builder, intermediateType, specializedFunc, true);
    builder->setInsertBefore(primalFunc);
    builder->addBackwardDerivativeIntermediateTypeDecoration(primalFunc, intermediateTypeGeneric);

    auto primalFuncGeneric =
        hoistValueFromGeneric(*builder, extractedPrimalFunc, specializedFunc, true);
    builder->setInsertBefore(primalFunc);

    // Copy over checkpoint preference hints. Only the decoration's opcode is copied.
    {
        auto diffPrimalFunc = getResolvedInstForDecorations(primalFuncGeneric, true);
        auto checkpointHint = primalFunc->findDecoration();
        if (checkpointHint)
            builder->addDecoration(diffPrimalFunc, checkpointHint->getOp());
    }

    if (auto existingDecor = primalFunc->findDecoration())
    {
        // If we already created a header for primal func, move the body into the existing primal
        // func header.
        auto existingPrimalHeader = existingDecor->getBackwardDerivativePrimalFunc();
        if (auto spec = as(existingPrimalHeader))
            existingPrimalHeader = spec->getBase();
        moveInstChildren(existingPrimalHeader, primalFuncGeneric);
        primalFuncGeneric->replaceUsesWith(existingPrimalHeader);
        primalFuncGeneric->removeAndDeallocate();
        primalFuncGeneric = existingPrimalHeader;
    }
    else
    {
        // Otherwise, record the newly extracted primal func on `primalFunc`.
        auto specializedBackwardPrimalFunc =
            maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric);
        builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc);
    }

    initializeLocalVariables(
        builder->getModule(),
        as(getGenericReturnVal(primalFuncGeneric)));
    initializeLocalVariables(builder->getModule(), diffPropagateFunc);
    stripTempDecorations(diffPropagateFunc);
    sortBlocksInFunc(diffPropagateFunc);
    sortBlocksInFunc(primalFunc);
}

ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock(
    IRBuilder* builder,
    IRFunc* diffFunc,
    SourceLoc primalLoc,
    bool isResultDifferentiable)
{
    // This method splits transposes the all the parameters for both the primal and
propagate // computation. At the end of this method, the parameter block will contain a combination of // parameters for both the to-be-primal function and to-be-propagate function. We use // ParameterBlockTransposeInfo::primalFuncParams and // ParameterBlockTransposeInfo::propagateFuncParams to track which parameters are dedicated to // the future primal or propagate func. A later step will then split the parameters out to each // new function. ParameterBlockTransposeInfo result; // First, we initialize the IR builders and locate the import code insertion points that will // be used for the rest of this method. IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); // Find the 'next' block using the terminator inst of the parameter block. auto fwdParamBlockBranch = as(fwdDiffParameterBlock->getTerminator()); // We create a new block after parameter block to hold insts that translates from transposed // parameters into something that the rest of the function can use. IRBuilder::insertBlockAlongEdge(diffFunc->getModule(), IREdge(&fwdParamBlockBranch->block)); auto paramPreludeBlock = fwdParamBlockBranch->getTargetBlock(); auto nextBlockBuilder = *builder; nextBlockBuilder.setInsertBefore(paramPreludeBlock->getFirstOrdinaryInst()); SourceLoc returnLoc; IRBlock* firstDiffBlock = nullptr; for (auto block : diffFunc->getBlocks()) { if (isDifferentialInst(block)) { firstDiffBlock = block; break; } auto terminator = block->getTerminator(); if (as(terminator)) { returnLoc = terminator->sourceLoc; break; } } SLANG_RELEASE_ASSERT(firstDiffBlock); auto diffBuilder = *builder; diffBuilder.setInsertBefore(firstDiffBlock->getFirstOrdinaryInst()); builder->setInsertBefore(fwdParamBlockBranch); // Collect all the original parameters. List fwdParams; for (auto param : diffFunc->getParams()) fwdParams.add(param); // Maintain a set for insts pending removal. OrderedHashSet instsToRemove; // Now we begin the actual processing. 
// The first step is to transcribe all the existing parameters from the original function. // There are many cases to handle, including different combinations of parameter directions and // whether or not the parameter is differentiable. // To normalize the process for all these cases, we determine the following actions for each // parameter: // 1. Should this original parameter be translated to a parameter in the primal func and the // propagate func? // if so, we emit a param inst representing the final parameter for that func. If the // parameter should be mapped to both the primal func and the propagate func, we will emit // two separate params with their final type. // 2. If this parameter has a corresponding primal func parameter, we replace all uses of the // original // parameter in the primal computation code to the new primal parameter. If any // initialization logic is needed to convert the type of the new primal parameter to what the // code was expecting, we insert that code in the first block. // 3. If this parameter has a correponding propagate func parameter, we replace all uses of the // original parameter // in the diff computation code to the new propagate parameter. We insert necessary // initialization diff block or the first block depending on whether we want that logic go // through the transposition pass. We may need to replace the uses to different // values/variables depending on whether that use is a read or write. // 4. If the parameter has both corresponding primal and propagate parameters, we also need to // consider // how the future propagate function access the primal parameter. We will insert necessary // preparation code that constructs temp vars or values to replace the primal parameter after // we remove it from the propagate func. // Base on above discussion, we need to compute the following values for each parameter: // - diffRefReplacement. 
What should all read(load) references to this parameter from // differential code be replaced to. // - diffRefWriteReplacement. What should all write references to this parameter from // differential code be replaced to. // - primalRefReplacement. What should all references to this parameter from primal code be // replaced to. // - mapPrimalSpecificParamToReplacementInPropFunc[param]. What should all references to this // parameter // from the primal compuation logic in the future propagate function be replaced to. for (auto fwdParam : fwdParams) { IRBuilderSourceLocRAII sourceLocationScope(builder, fwdParam->sourceLoc); // Define the replacement insts that we are going to fill in for each case. IRInst* diffRefReplacement = nullptr; IRInst* primalRefReplacement = nullptr; IRInst* diffWriteRefReplacement = nullptr; // Common logic that computes all the important types we care about. IRDifferentialPairType* diffPairType = as(fwdParam->getDataType()); auto inoutType = as(fwdParam->getDataType()); auto outType = as(fwdParam->getDataType()); if (inoutType) diffPairType = as(inoutType->getValueType()); else if (outType) diffPairType = as(outType->getValueType()); IRType* primalType = nullptr; IRType* diffType = nullptr; if (diffPairType) { primalType = diffPairType->getValueType(); diffType = (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( builder, diffPairType); } // Now we handle each combination of parameter direction x differentiability. if (outType) { // Case 1: out parameters. // Out parameters need to be handled differently whether or not it is differentiable, // since the propagate function will not have a corresponding output. if (diffPairType) { // Create dOut param. 
auto diffParam = builder->emitParam(diffType); copyNameHintAndDebugDecorations(diffParam, fwdParam); result.propagateFuncParams.add(diffParam); primalRefReplacement = builder->emitParam(builder->getOutParamType(primalType)); copyNameHintAndDebugDecorations(primalRefReplacement, fwdParam); // Create a local var for read access in pre-transpose code. // This will the var from which we will fetch the final resulting derivative // after transposition. auto tempVar = nextBlockBuilder.emitVar(diffType); copyNameHintAndDebugDecorations(tempVar, fwdParam); result.propagateFuncSpecificPrimalInsts.add(tempVar); // Initialize the var with input diff param at start. // Note that we insert the store in the primal block so it won't get transposed. auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam); nextBlockBuilder.markInstAsDifferential(storeInst, primalType); // Since this store inst is specific to propagate function, we track it in a // set so we can remove it when we generate the primal func. result.propagateFuncSpecificPrimalInsts.add(storeInst); diffWriteRefReplacement = tempVar; diffRefReplacement = tempVar; } else { primalRefReplacement = builder->emitParam(outType); copyNameHintAndDebugDecorations(primalRefReplacement, fwdParam); } result.primalFuncParams.add(primalRefReplacement); // Create a local var for the out param for the primal part of the prop func. auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType()); copyNameHintAndDebugDecorations(tempPrimalVar, fwdParam); result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar; instsToRemove.add(fwdParam); } else if (!isRelevantDifferentialPair(fwdParam->getDataType())) { if (inoutType) { // Case 2: non differentiable inout parameter. // They should become an inout parameter in primal func, but an in parameter in // bwd func. 
fwdParam->removeFromParent(); fwdDiffParameterBlock->addParam(fwdParam); result.primalFuncParams.add(fwdParam); primalRefReplacement = fwdParam; // Create an in param for the prop func. auto propParam = builder->emitParam(inoutType->getValueType()); copyNameHintAndDebugDecorations(propParam, fwdParam); result.propagateFuncParams.add(propParam); // Create a local var for the out param for the primal part of the prop func. auto tempPrimalVar = nextBlockBuilder.emitVar(inoutType->getValueType()); copyNameHintAndDebugDecorations(tempPrimalVar, fwdParam); result.propagateFuncSpecificPrimalInsts.add(tempPrimalVar); auto storeInst = nextBlockBuilder.emitStore(tempPrimalVar, propParam); result.propagateFuncSpecificPrimalInsts.add(storeInst); result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar; } else { // Case 3: non differentiable, non output parameters. // If parameter is not an out param and has nothing to do with differentiation, // simply move the parameter to the end. // fwdParam->removeFromParent(); fwdDiffParameterBlock->addParam(fwdParam); result.primalFuncParams.add(fwdParam); result.propagateFuncParams.add(fwdParam); continue; } } else if (!inoutType) { // Case 4: `in` differentiable parameters. SLANG_RELEASE_ASSERT(diffPairType); // Create inout version. auto inoutDiffPairType = builder->getBorrowInOutParamType(diffPairType); primalRefReplacement = builder->emitParam(primalType); copyNameHintAndDebugDecorations(primalRefReplacement, fwdParam); result.primalFuncParams.add(primalRefReplacement); auto propParam = builder->emitParam(inoutDiffPairType); copyNameHintAndDebugDecorations(propParam, fwdParam); result.propagateFuncParams.add(propParam); // A reference to this parameter from the diff blocks should be replaced with a load // of the differential component of the pair. 
auto newParamLoad = diffBuilder.emitLoad(propParam); diffBuilder.markInstAsDifferential(newParamLoad, primalType); result.propagateFuncSpecificPrimalInsts.add(newParamLoad); diffRefReplacement = diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad); diffBuilder.markInstAsDifferential(diffRefReplacement, primalType); result.propagateFuncSpecificPrimalInsts.add(diffRefReplacement); // Load the primal component from the prop param and use it as replacement for the // primal param in the primal part of the prop func. // Since these are logic specific to propagate function, we will add them to the // `propagateFuncSpecificPrimalInsts` set so we can remove them when we generate the // primal func. auto primalReplacementLoad = nextBlockBuilder.emitLoad(propParam); result.propagateFuncSpecificPrimalInsts.add(primalReplacementLoad); auto primalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(primalReplacementLoad); result.propagateFuncSpecificPrimalInsts.add(primalVal); result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = primalVal; instsToRemove.add(fwdParam); } else { // Case 5: `inout` differentiable parameters. SLANG_ASSERT(inoutType && diffPairType); // Process differentiable inout parameters. auto primalParam = builder->emitParam(builder->getBorrowInOutParamType(primalType)); copyNameHintAndDebugDecorations(primalParam, fwdParam); result.primalFuncParams.add(primalParam); auto diffParam = builder->emitParam(inoutType); copyNameHintAndDebugDecorations(diffParam, fwdParam); result.propagateFuncParams.add(diffParam); // Primal references to this param is the new primal param. primalRefReplacement = primalParam; // Diff references to this param should be replaced with one local temp var // for read and one separate temp var for write. // Load the inital diff value. 
auto loadedParam = nextBlockBuilder.emitLoad(diffParam); result.propagateFuncSpecificPrimalInsts.add(loadedParam); auto initDiff = nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam); result.propagateFuncSpecificPrimalInsts.add(initDiff); // Create a local var for diff read access. auto diffVar = nextBlockBuilder.emitVar(diffType); copyNameHintAndDebugDecorations(diffVar, fwdParam); result.propagateFuncSpecificPrimalInsts.add(diffVar); diffRefReplacement = diffVar; // Clear the diff read var to zero at start of the function. auto dzero = getDifferentialZeroOfType(&nextBlockBuilder, primalType); result.propagateFuncSpecificPrimalInsts.add(dzero); auto initDiffStore = nextBlockBuilder.emitStore(diffVar, dzero); result.propagateFuncSpecificPrimalInsts.add(initDiffStore); // Create a local var for diff write access. auto diffWriteVar = nextBlockBuilder.emitVar(diffType); result.propagateFuncSpecificPrimalInsts.add(diffWriteVar); copyNameHintAndDebugDecorations(diffWriteVar, fwdParam); // Initialize write var to 0. auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff); result.propagateFuncSpecificPrimalInsts.add(writeStore); diffWriteRefReplacement = diffWriteVar; // Create a local var for the primal logic in the propagate func. 
auto primalVar = nextBlockBuilder.emitVar(primalType); copyNameHintAndDebugDecorations(primalVar, fwdParam); result.propagateFuncSpecificPrimalInsts.add(primalVar); auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam); result.propagateFuncSpecificPrimalInsts.add(initPrimalVal); auto storeInst = nextBlockBuilder.emitStore(primalVar, initPrimalVal); result.propagateFuncSpecificPrimalInsts.add(storeInst); result.mapPrimalSpecificParamToReplacementInPropFunc[primalParam] = primalVar; result.outDiffWritebacks[diffParam] = InstPair(initPrimalVal, diffVar); instsToRemove.add(fwdParam); } // We have emitted all the new parameters and computed the replacements for the original // parameter. Now we perform that replacement. List uses; for (auto use = fwdParam->firstUse; use; use = use->nextUse) uses.add(use); for (auto use : uses) { if (auto primalRef = as(use->getUser())) { SLANG_RELEASE_ASSERT(primalRefReplacement); primalRef->replaceUsesWith(primalRefReplacement); instsToRemove.add(primalRef); } else if (auto getPrimal = as(use->getUser())) { SLANG_RELEASE_ASSERT(primalRefReplacement); getPrimal->replaceUsesWith(primalRefReplacement); instsToRemove.add(getPrimal); } else if (auto propagateRef = as(use->getUser())) { SLANG_RELEASE_ASSERT(diffRefReplacement); auto refUse = propagateRef->firstUse; while (refUse) { auto nextUse = refUse->nextUse; // Is this use the dest operand of a store inst? // If so, replace it with writeRefReplacement, otherwise, refReplacement. 
if (refUse->getUser()->getOp() == kIROp_Store && refUse == refUse->getUser()->getOperands()) { SLANG_RELEASE_ASSERT(diffWriteRefReplacement); refUse->set(diffWriteRefReplacement); } else { refUse->set(diffRefReplacement); } refUse = nextUse; } instsToRemove.add(propagateRef); } else if (auto getDiff = as(use->getUser())) { SLANG_RELEASE_ASSERT(diffRefReplacement); getDiff->replaceUsesWith(diffRefReplacement); instsToRemove.add(getDiff); } else { // If the user is something else, it'd better be a non relevant parameter. if (diffRefReplacement || diffWriteRefReplacement) SLANG_UNEXPECTED("unknown use of parameter."); use->set(primalRefReplacement); } } } // Actually remove all the insts that we decided to remove in the process. for (auto inst : instsToRemove) { inst->removeAndDeallocate(); } // The next step is to insert new parameters that is not related to any existing parameters. // // If the return type of the original function is differentiable, // add a parameter for 'derivative of the output' (d_out). // The type is the second last parameter type of the function. // auto paramCount = as(diffFunc->getDataType())->getParamCount(); IRParam* dOutParam = nullptr; if (isResultDifferentiable) { auto dOutParamType = as(diffFunc->getDataType())->getParamType(paramCount - 2); SLANG_ASSERT(dOutParamType); dOutParam = builder->emitParam(dOutParamType); dOutParam->sourceLoc = returnLoc; builder->addNameHintDecoration(dOutParam, UnownedStringSlice("_s_dOut")); result.propagateFuncParams.add(dOutParam); } // Add a parameter for intermediate val. 
auto ctxParam = builder->emitParam(as(diffFunc->getDataType())->getParamType(paramCount - 1)); builder->addNameHintDecoration(ctxParam, UnownedStringSlice("_s_diff_ctx")); builder->addDecoration(ctxParam, kIROp_PrimalContextDecoration); result.primalFuncParams.add(ctxParam); result.propagateFuncParams.add(ctxParam); result.dOutParam = dOutParam; diffFunc->sourceLoc = primalLoc; ctxParam->sourceLoc = primalLoc; return result; } void BackwardDiffTranscriberBase::writeBackDerivativeToInOutParams( ParameterBlockTransposeInfo& info, IRFunc* diffFunc) { IRInst* returnInst = nullptr; for (auto block : diffFunc->getBlocks()) { for (auto inst : block->getChildren()) { if (inst->getOp() == kIROp_Return) { returnInst = inst; break; } } } SLANG_RELEASE_ASSERT(returnInst); IRBuilder builder(autoDiffSharedContext->moduleInst); builder.setInsertBefore(returnInst); for (auto& wb : info.outDiffWritebacks) { auto dest = wb.key; auto srcPrimalVal = wb.value.primal; auto srcDiffAddr = wb.value.differential; auto srcDiffVal = builder.emitLoad(srcDiffAddr); auto destVal = builder.emitMakeDifferentialPair( as(dest->getFullType())->getValueType(), srcPrimalVal, srcDiffVal); builder.emitStore(dest, destVal); } } InstPair BackwardDiffTranscriberBase::transcribeSpecialize( IRBuilder* builder, IRSpecialize* origSpecialize) { auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); List primalArgs; for (UInt i = 0; i < origSpecialize->getArgCount(); i++) { primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i))); } auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType()); auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst( (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer()); if (auto diffBase = instMapD.tryGetValue(origSpecialize->getBase())) { List args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { args.add(primalSpecialize->getArg(i)); } auto diffSpecialize = 
builder->emitSpecializeInst( builder->getTypeKind(), *diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } auto genericInnerVal = findInnerMostGenericReturnVal(as(origSpecialize->getBase())); // Look for an IRBackwardDerivativeDecoration on the specialize inst. // (Normally, this would be on the inner IRFunc, but in this case only the JVP func // can be specialized, so we put a decoration on the IRSpecialize) // if (auto derivativeFunc = findExistingDiffFunc(origSpecialize)) { // Make sure this isn't itself a specialize . SLANG_RELEASE_ASSERT(!as(derivativeFunc)); return InstPair(primalSpecialize, derivativeFunc); } else if (auto diffBase = findExistingDiffFunc(genericInnerVal)) { List args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { args.add(primalSpecialize->getArg(i)); } // A `BackwardDerivative` decoration on an inner func of a generic should always be a // `specialize`. auto diffBaseSpecialize = as(diffBase); SLANG_RELEASE_ASSERT(diffBaseSpecialize); // Note: this assumes that the generic arguments to specialize the derivative is the same as // the generic args to specialize the primal function. This is true for all of our core // module functions, but we may need to rely on more general substitution logic here. 
auto diffSpecialize = builder->emitSpecializeInst( builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } else if (isBackwardDifferentiableFunc(genericInnerVal) || as(genericInnerVal)) { List args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { args.add(primalSpecialize->getArg(i)); } auto diffCallee = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); auto diffSpecialize = builder->emitSpecializeInst( builder->getTypeKind(), diffCallee, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } else { return InstPair(primalSpecialize, nullptr); } } } // namespace Slang