diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 36 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 11 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 477 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.h | 40 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 72 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 99 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 108 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 101 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 17 | ||||
| -rw-r--r-- | source/slang/slang-ir-clone.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-clone.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 18 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 36 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-ssa.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 50 |
18 files changed, 848 insertions, 249 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index adbf8ae48..055c44135 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -313,6 +313,24 @@ DifferentialPair<vector<T, N>> __d_sin_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_sin, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(sin)] +void __d_sin(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(cos(dpx.p), dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(sin)] +void __d_sin_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + dpx = diffPair( + dpx.p, + vector<T, N>.dmul(cos(dpx.p), dOut)); +} + // Cosine __generic<T : __BuiltinFloatingPointType> @@ -331,6 +349,24 @@ DifferentialPair<vector<T, N>> __d_cos_vector(DifferentialPair<vector<T, N>> dpx VECTOR_MAP_D_UNARY(T, N, __d_cos, dpx); } +__generic<T : __BuiltinFloatingPointType> +[BackwardDerivativeOf(cos)] +void __d_cos(inout DifferentialPair<T> dpx, T.Differential dOut) +{ + dpx = diffPair( + dpx.p, + T.dmul(-sin(dpx.p), dOut)); +} + +__generic<T : __BuiltinFloatingPointType, let N : int> +[BackwardDerivativeOf(cos)] +void __d_cos_vector(inout DifferentialPair<vector<T, N>> dpx, vector<T, N>.Differential dOut) +{ + dpx = diffPair( + dpx.p, + vector<T, N>.dmul(-sin(dpx.p), dOut)); +} + // Base-e logarithm __generic<T : __BuiltinFloatingPointType> diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index b89eb85c4..a52a08f15 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -478,6 +478,7 @@ namespace Slang if (!parent) return nullptr; + // If we reach here, we are expecting a synthesized decl defined in `subType`. // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl // in `subType` and return a DeclRefExpr to the synthesized decl. 
@@ -862,6 +863,15 @@ namespace Slang if (auto declRefType = as<DeclRefType>(type)) { + if (auto builtinRequirement = declRefType->declRef.getDecl()->findModifier<BuiltinRequirementModifier>()) + { + if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType) + { + // We are trying to get differential type from a differential type. + // The result is itself. + return type; + } + } if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface()))) { auto diffTypeLookupResult = lookUpMember( @@ -2328,6 +2338,13 @@ namespace Slang { for (auto param : funcDecl->getParameters()) { + if (param->findModifier<NoDiffModifier>()) + { + if (param->findModifier<OutModifier>() && + !param->findModifier<InModifier>() && + !param->findModifier<InOutModifier>()) + continue; + } resultDiffExpr->newParameterNames.add(param->getName()); } resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index a9e716ce4..16b7a977c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -402,6 +402,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto primalArg = findOrTranscribePrimalInst(&argBuilder, origArg); SLANG_ASSERT(primalArg); + auto origType = origCall->getArg(ii)->getDataType(); auto primalType = primalArg->getDataType(); auto paramType = calleeType->getParamType(ii); if (!isNoDiffType(paramType)) @@ -410,8 +411,10 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig { while (auto attrType = as<IRAttributedType>(primalType)) primalType = attrType->getBaseType(); + while (auto attrType = as<IRAttributedType>(origType)) + origType = attrType->getBaseType(); } - if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) + if (auto pairType = tryGetDiffPairType(&argBuilder, origType)) { auto 
pairPtrType = as<IRPtrTypeBase>(pairType); auto pairValType = as<IRDifferentialPairType>( @@ -1201,7 +1204,7 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I // Find and clone `DifferentiableTypeDictionaryDecoration` to the new diffFunc. if (auto dictDecor = origFunc->findDecoration<IRDifferentiableTypeDictionaryDecoration>()) { - cloneDecoration(dictDecor, diffFunc); + cloneDecoration(&cloneEnv, dictDecor, diffFunc, diffFunc->getModule()); } return diffFunc; } @@ -1434,8 +1437,10 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); // Make a local copy of the parameter for primal and diff parts. auto primal = builder->emitVar(ptrInnerPairType->getValueType()); + auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType()); auto diff = builder->emitVar(diffType); + builder->markInstAsDifferential(diff, ptrInnerPairType->getValueType()); IRInst* primalInitVal = nullptr; IRInst* diffInitVal = nullptr; @@ -1447,6 +1452,8 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam else { auto initVal = builder->emitLoad(diffPairParam); + builder->markInstAsMixedDifferential(initVal, ptrInnerPairType); + primalInitVal = builder->emitDifferentialPairGetPrimal(initVal); diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal); } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 9c63a4012..67387e83a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -417,7 +417,10 @@ namespace Slang } else { - auto var = builder.emitVar(primalParamType); + auto primalPtrType = as<IRPtrTypeBase>(primalParamType); + SLANG_RELEASE_ASSERT(primalPtrType); + auto primalValueType = primalPtrType->getValueType(); + auto var = builder.emitVar(primalValueType); 
primalArgs.add(var); } primalTypes.add(primalParamType); @@ -515,12 +518,12 @@ namespace Slang { auto ptrType = as<IRPtrTypeBase>(param->getDataType()); auto tempVar = builder.emitVar(ptrType->getValueType()); + param->replaceUsesWith(tempVar); mapParamToTempVar[param] = tempVar; - if (param->getOp() != kIROp_OutType) + if (ptrType->getOp() != kIROp_OutType) { builder.emitStore(tempVar, builder.emitLoad(param)); } - param->replaceUsesWith(tempVar); } for (auto block : func->getBlocks()) @@ -578,6 +581,43 @@ namespace Slang return result; } + void eliminateRedundantLoad(IRFunc* func) + { + for (auto block : func->getBlocks()) + { + for (auto inst = block->getFirstInst(); inst;) + { + auto nextInst = inst->getNextInst(); + if (auto load = as<IRLoad>(inst)) + { + for (auto prev = inst->getPrevInst(); prev; prev = prev->getPrevInst()) + { + if (auto store = as<IRStore>(prev)) + { + if (store->getPtr() == load->getPtr()) + { + // If the load is preceeded by a store without any side-effect insts in-between, remove the load. + auto value = store->getVal(); + load->replaceUsesWith(value); + load->removeAndDeallocate(); + break; + } + } + else if (as<IRCall>(prev)) + { + break; + } + else if (prev->mightHaveSideEffects()) + { + break; + } + } + } + inst = nextInst; + } + } + } + // Create a copy of originalFunc's forward derivative in the same generic context (if any) of // `diffPropagateFunc`. IRFunc* BackwardDiffTranscriberBase::generateNewForwardDerivativeForFunc( @@ -621,6 +661,9 @@ namespace Slang // Remove the clone of original func. primalOuterParent->removeAndDeallocate(); + // Remove redundant loads since they interfere with transposition logic. + eliminateRedundantLoad(fwdDiffFunc); + // Migrate the new forward derivative function into the generic parent of `diffPropagateFunc`. 
if (auto fwdParentGeneric = as<IRGeneric>(findOuterGeneric(fwdDiffFunc))) { @@ -645,6 +688,7 @@ namespace Slang } fwdParentGeneric->removeAndDeallocate(); } + return fwdDiffFunc; } @@ -665,6 +709,20 @@ namespace Slang return InstPair(primalInst, nullptr); } + // Keep primal param replacement insts alive during DCE. + static void _lockPrimalParamReplacementInsts(IRBuilder* builder, ParameterBlockTransposeInfo& paramInfo) + { + for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc) + builder->addKeepAliveDecoration(kv.Value); + } + + // Remove [KeepAlive] decorations for primal param replacement insts. + static void _unlockPrimalParamReplacementInsts(ParameterBlockTransposeInfo& paramInfo) + { + for (auto& kv : paramInfo.mapPrimalSpecificParamToReplacementInPropFunc) + kv.Value->findDecoration<IRKeepAliveDecoration>()->removeAndDeallocate(); + } + // Transcribe a function definition. void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc) { @@ -702,9 +760,9 @@ namespace Slang // Copy primal insts to the first block of the unzipped function, copy diff insts to the // second block of the unzipped function. // - IRFunc* unzippedFwdDiffFunc = diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + diffUnzipPass->unzipDiffInsts(fwdDiffFunc); + IRFunc* unzippedFwdDiffFunc = fwdDiffFunc; - // Move blocks from `unzippedFwdDiffFunc` to the `diffPropagateFunc` shell. 
builder->setInsertInto(diffPropagateFunc->getParent()); { @@ -717,13 +775,19 @@ namespace Slang } // Transpose the first block (parameter block) - List<IRInst*> primalFuncSpecificParams; - auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable); + auto paramTransposeInfo = + splitAndTransposeParameterBlock(builder, diffPropagateFunc, isResultDifferentiable); + + // The insts we inserted in paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc + // may be used by write back logic that we are going to insert later. + // Before then we want to keep them alive. + _lockPrimalParamReplacementInsts(builder, paramTransposeInfo); builder->setInsertInto(diffPropagateFunc); - // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the - DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; + // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the + // derivative of the return value. + DiffTransposePass::FuncTranspositionInfo info = { paramTransposeInfo.dOutParam, nullptr}; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); eliminateDeadCode(diffPropagateFunc); @@ -732,32 +796,36 @@ namespace Slang // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( - diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType); + diffPropagateFunc, primalFunc, paramTransposeInfo, intermediateType); - // Clean up by deallocating the tempoarary forward derivative func. - fwdDiffFunc->removeAndDeallocate(); + // At this point the unzipped func is just an empty shell + // and we can simply remove it. + unzippedFwdDiffFunc->removeAndDeallocate(); - // Remove primalFuncSpecificParams. 
- for (auto specificParam : primalFuncSpecificParams) + // Write back derivatives to inout parameters. + writeBackDerivativeToInOutParams(paramTransposeInfo, diffPropagateFunc); + + // Remove primalFunc specific params. + List<IRInst*> paramsToRemove; + for (auto param : diffPropagateFunc->getParams()) + { + if (!paramTransposeInfo.propagateFuncParams.Contains(param)) + paramsToRemove.add(param); + } + for (auto param : paramsToRemove) { - while (auto use = specificParam->firstUse) + if (param->hasUses()) { - if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands()) - { - use->getUser()->removeAndDeallocate(); - } - else if (auto decor = as<IRDecoration>(use->getUser())) - { - decor->removeAndDeallocate(); - } - else - { - SLANG_UNEXPECTED("unexpected use of transcribed param."); - } + IRInst* replacement = nullptr; + paramTransposeInfo.mapPrimalSpecificParamToReplacementInPropFunc.TryGetValue(param, replacement); + SLANG_RELEASE_ASSERT(replacement); + param->replaceUsesWith(replacement); } - specificParam->removeAndDeallocate(); + param->removeAndDeallocate(); } + _unlockPrimalParamReplacementInsts(paramTransposeInfo); + // If primal function is nested in a generic, we want to create separate generics for all the associated things // we have just created. auto primalOuterGeneric = findOuterGeneric(primalFunc); @@ -789,87 +857,322 @@ namespace Slang initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc); } - IRInst* BackwardDiffTranscriberBase::transposeParameterBlock( + ParameterBlockTransposeInfo BackwardDiffTranscriberBase::splitAndTransposeParameterBlock( IRBuilder* builder, IRFunc* diffFunc, - List<IRInst*>& primalFuncSpecificParams, bool isResultDifferentiable) { + // This method splits transposes the all the parameters for both the primal and propagate computation. 
+ // At the end of this method, the parameter block will contain a combination of parameters for + // both the to-be-primal function and to-be-propagate function. + // We use ParameterBlockTransposeInfo::primalFuncParams and ParameterBlockTransposeInfo::propagateFuncParams + // to track which parameters are dedicated to the future primal or propagate func. + // A later step will then split the parameters out to each new function. + + ParameterBlockTransposeInfo result; + + // First, we initialize the IR builders and locate the import code insertion points that will + // be used for the rest of this method. + IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); // Find the 'next' block using the terminator inst of the parameter block. auto fwdParamBlockBranch = as<IRUnconditionalBranch>(fwdDiffParameterBlock->getTerminator()); auto nextBlock = fwdParamBlockBranch->getTargetBlock(); - builder->setInsertBefore(fwdParamBlockBranch); + auto nextBlockBuilder = *builder; + nextBlockBuilder.setInsertBefore(nextBlock->getFirstOrdinaryInst()); - List<IRParam*> fwdParams; - for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) + IRBlock* firstDiffBlock = nullptr; + for (auto block : diffFunc->getBlocks()) { - fwdParams.add(child); + if (isDifferentialInst(block)) + { + firstDiffBlock = block; + break; + } } - // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> + SLANG_RELEASE_ASSERT(firstDiffBlock); + + auto diffBuilder = *builder; + diffBuilder.setInsertBefore(firstDiffBlock->getFirstOrdinaryInst()); + + builder->setInsertBefore(fwdParamBlockBranch); + + // Collect all the original parameters. + List<IRParam*> fwdParams; + for (auto param : diffFunc->getParams()) + fwdParams.add(param); + + // Maintain a set for insts pending removal. + OrderedHashSet<IRInst*> instsToRemove; + + // Now we begin the actual processing. 
+ // The first step is to transcribe all the existing parameters from the original function. + // There are many cases to handle, including different combinations of parameter directions and + // whether or not the parameter is differentiable. + // To normalize the process for all these cases, we determine the following actions for each parameter: + // 1. Should this original parameter be translated to a parameter in the primal func and the propagate func? + // if so, we emit a param inst representing the final parameter for that func. If the parameter should be + // mapped to both the primal func and the propagate func, we will emit two separate params with their + // final type. + // 2. If this parameter has a corresponding primal func parameter, we replace all uses of the original + // parameter in the primal computation code to the new primal parameter. If any initialization logic + // is needed to convert the type of the new primal parameter to what the code was expecting, we insert + // that code in the first block. + // 3. If this parameter has a correponding propagate func parameter, we replace all uses of the original parameter + // in the diff computation code to the new propagate parameter. We insert necessary initialization diff block or the first block + // depending on whether we want that logic go through the transposition pass. We may need to replace the uses + // to different values/variables depending on whether that use is a read or write. + // 4. If the parameter has both corresponding primal and propagate parameters, we also need to consider + // how the future propagate function access the primal parameter. We will insert necessary preparation code + // that constructs temp vars or values to replace the primal parameter after we remove it from the + // propagate func. + // Base on above discussion, we need to compute the following values for each parameter: + // - diffRefReplacement. 
What should all read(load) references to this parameter from differential code be replaced to. + // - diffRefWriteReplacement. What should all write references to this parameter from differential code be replaced to. + // - primalRefReplacement. What should all references to this parameter from primal code be replaced to. + // - mapPrimalSpecificParamToReplacementInPropFunc[param]. What should all references to this parameter + // from the primal compuation logic in the future propagate function be replaced to. for (auto fwdParam : fwdParams) { - if (auto outType = as<IROutType>(fwdParam->getDataType())) + // Define the replacement insts that we are going to fill in for each case. + IRInst* diffRefReplacement = nullptr; + IRInst* primalRefReplacement = nullptr; + IRInst* diffWriteRefReplacement = nullptr; + + // Common logic that computes all the important types we care about. + IRDifferentialPairType* diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType()); + auto inoutType = as<IRInOutType>(fwdParam->getDataType()); + auto outType = as<IROutType>(fwdParam->getDataType()); + if (inoutType) + diffPairType = as<IRDifferentialPairType>(inoutType->getValueType()); + else if (outType) + diffPairType = as<IRDifferentialPairType>(outType->getValueType()); + IRType* primalType = nullptr; + IRType* diffType = nullptr; + if (diffPairType) + { + primalType = diffPairType->getValueType(); + diffType = (IRType*)differentiableTypeConformanceContext + .getDifferentialTypeFromDiffPairType(builder, diffPairType); + } + + // Now we handle each combination of parameter direction x differentiability. + if (outType) { - IRParam* newPropParam = nullptr; - IRParam* newPrimalParam = nullptr; - auto diffPairType = as<IRDifferentialPairType>(outType->getValueType()); + // Case 1: out parameters. + // Out parameters need to be handled differently whether or not it is differentiable, + // since the propagate function will not have a corresponding output. 
if (diffPairType) { // Create dOut param. - auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType); - newPropParam = builder->emitParam(diffType); - newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType())); + auto diffParam = builder->emitParam(diffType); + result.propagateFuncParams.Add(diffParam); + primalRefReplacement = builder->emitParam(builder->getOutType(primalType)); + + // Create a local var for read access in pre-transpose code. + // This will the var from which we will fetch the final resulting derivative + // after transposition. + auto tempVar = nextBlockBuilder.emitVar(diffType); + nextBlockBuilder.markInstAsDifferential(tempVar, diffPairType); + + // Initialize the var with input diff param at start. + // Note that we insert the store in the primal block so it won't get transposed. + auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam); + nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType); + // Since this store inst is specific to propagate function, we track it in a + // set so we can remove it when we generate the primal func. + result.propagateFuncSpecificPrimalInsts.add(storeInst); + + diffWriteRefReplacement = tempVar; + diffRefReplacement = tempVar; } else { - newPrimalParam = builder->emitParam(outType); - } - - // Create a temp var to represent the original `out` param. - auto arg = builder->emitVar(outType->getValueType()); - builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam); - if (newPropParam) - { - builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam); + primalRefReplacement = builder->emitParam(outType); } + result.primalFuncParams.Add(primalRefReplacement); - fwdParam->replaceUsesWith(arg); - fwdParam->removeAndDeallocate(); + // Create a local var for the out param for the primal part of the prop func. 
+ auto tempPrimalVar = nextBlockBuilder.emitVar(outType->getValueType()); + result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = tempPrimalVar; - primalFuncSpecificParams.add(newPrimalParam); + instsToRemove.Add(fwdParam); } - else if (auto diffPairType = as<IRDifferentialPairType>(fwdParam->getDataType())) + else if (!isRelevantDifferentialPair(fwdParam->getDataType())) { + // Case 2: non differentiable, non output parameters. + // If parameter is not an out param and has nothing to do with differentiation, + // simply move the parameter to the end. + // + fwdParam->removeFromParent(); + fwdDiffParameterBlock->addParam(fwdParam); + result.primalFuncParams.Add(fwdParam); + result.propagateFuncParams.Add(fwdParam); + continue; + } + else if(!inoutType) + { + // Case 4: `in` differentiable parameters. + + SLANG_RELEASE_ASSERT(diffPairType); + // Create inout version. auto inoutDiffPairType = builder->getInOutType(diffPairType); - auto newParam = builder->emitParam(inoutDiffPairType); - - // Map the _load_ of the new parameter as the clone of the old one. - auto newParamLoad = builder->emitLoad(newParam); - newParamLoad->insertAtStart(nextBlock); // Move to first block _after_ the parameter block. - fwdParam->replaceUsesWith(newParamLoad); - fwdParam->removeAndDeallocate(); + primalRefReplacement = builder->emitParam(primalType); + result.primalFuncParams.Add(primalRefReplacement); + auto propParam = builder->emitParam(inoutDiffPairType); + result.propagateFuncParams.Add(propParam); + + // A reference to this parameter from the diff blocks should be replaced with a load + // of the differential component of the pair. 
+ auto newParamLoad = diffBuilder.emitLoad(propParam); + diffBuilder.markInstAsDifferential(newParamLoad, primalType); + + diffRefReplacement = diffBuilder.emitDifferentialPairGetDifferential(diffType, newParamLoad); + diffBuilder.markInstAsDifferential(diffRefReplacement, primalType); + + // Load the primal component from the prop param and use it as replacement for the + // primal param in the primal part of the prop func. + // Since these are logic specific to propagate function, we will add them to the + // `propagateFuncSpecificPrimalInsts` set so we can remove them when we generate the primal func. + auto primalReplacementLoad = nextBlockBuilder.emitLoad(propParam); + result.propagateFuncSpecificPrimalInsts.add(primalReplacementLoad); + auto primalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(primalReplacementLoad); + result.propagateFuncSpecificPrimalInsts.add(primalVal); + result.mapPrimalSpecificParamToReplacementInPropFunc[primalRefReplacement] = primalVal; + + instsToRemove.Add(fwdParam); } else { - // Default case (parameter is inout type or has nothing to do with differentiation) - // Simply move the parameter to the end. - // - fwdParam->removeFromParent(); - fwdDiffParameterBlock->addParam(fwdParam); + // Case 5: `inout` differentiable parameters. + SLANG_ASSERT(inoutType && diffPairType); + + // Process differentiable inout parameters. + auto primalParam = builder->emitParam(builder->getInOutType(primalType)); + result.primalFuncParams.Add(primalParam); + + auto diffParam = builder->emitParam(inoutType); + result.propagateFuncParams.Add(diffParam); + + // Primal references to this param is the new primal param. + primalRefReplacement = primalParam; + + // Diff references to this param should be replaced with one local temp var + // for read and one separate temp var for write. + + // Load the inital diff value. 
+ auto loadedParam = nextBlockBuilder.emitLoad(diffParam); + auto initDiff = nextBlockBuilder.emitDifferentialPairGetDifferential(diffType, loadedParam); + + // Create a local var for diff read access. + auto diffVar = nextBlockBuilder.emitVar(diffType); + result.propagateFuncSpecificPrimalInsts.add(diffVar); + diffBuilder.markInstAsDifferential(diffVar, diffPairType); + diffRefReplacement = diffVar; + + // Clear the diff read var to zero at start of the function. + auto dzero = getDifferentialZeroOfType(&nextBlockBuilder, primalType); + result.propagateFuncSpecificPrimalInsts.add(dzero); + auto initDiffStore = nextBlockBuilder.emitStore(diffVar, dzero); + result.propagateFuncSpecificPrimalInsts.add(initDiffStore); + + // Create a local var for diff write access. + auto diffWriteVar = nextBlockBuilder.emitVar(diffType); + // Initialize write var to 0. + auto writeStore = nextBlockBuilder.emitStore(diffWriteVar, initDiff); + result.propagateFuncSpecificPrimalInsts.add(writeStore); + + diffWriteRefReplacement = diffWriteVar; + + // Create a local var for the primal logic in the propagate func. + auto primalVar = nextBlockBuilder.emitVar(primalType); + result.propagateFuncSpecificPrimalInsts.add(primalVar); + auto initPrimalVal = nextBlockBuilder.emitDifferentialPairGetPrimal(loadedParam); + result.propagateFuncSpecificPrimalInsts.add(initPrimalVal); + auto storeInst = nextBlockBuilder.emitStore(primalVar, initPrimalVal); + result.propagateFuncSpecificPrimalInsts.add(storeInst); + result.mapPrimalSpecificParamToReplacementInPropFunc[primalParam] = primalVar; + result.outDiffWritebacks[diffParam] = InstPair(initPrimalVal, diffVar); + + instsToRemove.Add(fwdParam); + } + + // We have emitted all the new parameters and computed the replacements for the original + // parameter. Now we perform that replacement. 
+ List<IRUse*> uses; + for (auto use = fwdParam->firstUse; use; use = use->nextUse) + uses.add(use); + for (auto use : uses) + { + if (auto primalRef = as<IRPrimalParamRef>(use->getUser())) + { + SLANG_RELEASE_ASSERT(primalRefReplacement); + primalRef->replaceUsesWith(primalRefReplacement); + instsToRemove.Add(primalRef); + } + else if (auto getPrimal = as<IRDifferentialPairGetPrimal>(use->getUser())) + { + SLANG_RELEASE_ASSERT(primalRefReplacement); + getPrimal->replaceUsesWith(primalRefReplacement); + instsToRemove.Add(getPrimal); + } + else if (auto propagateRef = as<IRDiffParamRef>(use->getUser())) + { + SLANG_RELEASE_ASSERT(diffRefReplacement); + auto refUse = propagateRef->firstUse; + while (refUse) + { + auto nextUse = refUse->nextUse; + switch (refUse->getUser()->getOp()) + { + case kIROp_Load: + refUse->set(diffRefReplacement); + break; + case kIROp_Store: + refUse->set(diffWriteRefReplacement); + break; + default: + SLANG_RELEASE_ASSERT(!diffWriteRefReplacement); + refUse->set(diffRefReplacement); + } + refUse = nextUse; + } + instsToRemove.Add(propagateRef); + } + else if (auto getDiff = as<IRDifferentialPairGetDifferential>(use->getUser())) + { + SLANG_RELEASE_ASSERT(diffRefReplacement); + getDiff->replaceUsesWith(diffRefReplacement); + instsToRemove.Add(getDiff); + } + else + { + // If the user is something else, it'd better be a non relevant parameter. + if (diffRefReplacement || diffWriteRefReplacement) + SLANG_UNEXPECTED("unknown use of parameter."); + use->set(primalRefReplacement); + } } } - auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); + // Actually remove all the insts that we decided to remove in the process. + for (auto inst : instsToRemove) + { + inst->removeAndDeallocate(); + } + - // 2. If the return type of the original function is differentiable, + // The next step is to insert new parameters that is not related to any existing parameters. 
+ // + // If the return type of the original function is differentiable, // add a parameter for 'derivative of the output' (d_out). // The type is the second last parameter type of the function. // + auto paramCount = as<IRFuncType>(diffFunc->getDataType())->getParamCount(); IRParam* dOutParam = nullptr; if (isResultDifferentiable) { @@ -878,12 +1181,44 @@ namespace Slang SLANG_ASSERT(dOutParamType); dOutParam = builder->emitParam(dOutParamType); + result.propagateFuncParams.Add(dOutParam); } // Add a parameter for intermediate val. - builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); + auto ctxParam = builder->emitParam(as<IRFuncType>(diffFunc->getDataType())->getParamType(paramCount - 1)); + result.primalFuncParams.Add(ctxParam); + result.propagateFuncParams.Add(ctxParam); + result.dOutParam = dOutParam; + return result; + } - return dOutParam; + void BackwardDiffTranscriberBase::writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc) + { + IRInst* returnInst = nullptr; + for (auto block : diffFunc->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + returnInst = inst; + break; + } + } + } + SLANG_RELEASE_ASSERT(returnInst); + + IRBuilder builder(sharedBuilder); + builder.setInsertBefore(returnInst); + for (auto& wb : info.outDiffWritebacks) + { + auto dest = wb.Key; + auto srcPrimalVal = wb.Value.primal; + auto srcDiffAddr = wb.Value.differential; + auto srcDiffVal = builder.emitLoad(srcDiffAddr); + auto destVal = builder.emitMakeDifferentialPair(as<IRPtrTypeBase>(dest->getFullType())->getValueType(), srcPrimalVal, srcDiffVal); + builder.emitStore(dest, destVal); + } } InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index 617e6b79b..c8e115f0e 100644 --- a/source/slang/slang-ir-autodiff-rev.h 
+++ b/source/slang/slang-ir-autodiff-rev.h @@ -20,6 +20,31 @@ struct IRReverseDerivativePassOptions // Nothing for now.. }; +// The result of function parameter transposition. +// Contains necessary info for future processing in the backward differentation pass. +struct ParameterBlockTransposeInfo +{ + // Parameters that should be in the furture primal function. + HashSet<IRInst*> primalFuncParams; + + // Parameters that should be in the furture propagate function. + HashSet<IRInst*> propagateFuncParams; + + // The value with which a primal specific parameter should be replaced in propagate func. + OrderedDictionary<IRInst*, IRInst*> mapPrimalSpecificParamToReplacementInPropFunc; + + // The insts added that is specific for propagate functions and should be removed + // from the future primal func. + List<IRInst*> propagateFuncSpecificPrimalInsts; + + // Write backs to perform at the end of the back-prop function in order to return the + // computed output derivatives for an inout parameter. + OrderedDictionary<IRInst*, InstPair> outDiffWritebacks; + + // The dOut parameter representing the result derivative to propagate backwards through. + IRInst* dOutParam; +}; + struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase { FuncBodyTranscriptionTaskType diffTaskType; @@ -70,9 +95,18 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase // Transcribe a function definition. virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0; - // Transcribes the parameter block and returns the dOut param if exists. - IRInst* transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc, List<IRInst*>& primalFuncSpecificParams, bool isResultDifferentiable); - + // Splits and transpose the parameter block. + // After this operation, the parameter block will contain parameters for both the future + // primal func and the future propagate func. 
+ // Additional info is returned in `ParameterBlockTransposeInfo` for future processing such + // as inserting write-back logic or splitting them into different functions. + ParameterBlockTransposeInfo splitAndTransposeParameterBlock( + IRBuilder* builder, + IRFunc* diffFunc, + bool isResultDifferentiable); + + void writeBackDerivativeToInOutParams(ParameterBlockTransposeInfo& info, IRFunc* diffFunc); + InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType); InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 8f21e8c62..31a3072c0 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -221,7 +221,19 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI auto primalType = lookupPrimalInst(builder, originalType, nullptr); SLANG_RELEASE_ASSERT(primalType); - IRInst* witness = tryGetDifferentiableWitness(builder, originalType); + IRInst* witness = nullptr; + if (auto lookup = as<IRLookupWitnessMethod>(primalType)) + { + if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey) + { + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + autoDiffSharedContext->differentialAssocTypeWitnessStructKey); + } + } + if (!witness) + witness = tryGetDifferentiableWitness(builder, originalType); SLANG_RELEASE_ASSERT(witness); return builder->getDifferentialPairType( @@ -239,6 +251,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialForType(builder, origType); return (IRType*)findOrTranscribePrimalInst(builder, diffType); } + else if (origType->getOp() == kIROp_LookupWitness) + { + return 
(IRType*)findOrTranscribePrimalInst(builder, (IRInst*)primalType); + } return (IRType*)transcribe(builder, origType); } @@ -539,6 +555,39 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui { return InstPair(primal, nullptr); } + if (interfaceType == autoDiffSharedContext->differentiableInterfaceType) + { + if (primalKey == autoDiffSharedContext->differentialAssocTypeStructKey) + { + return InstPair(primal, primal); + } + else if (primalKey == autoDiffSharedContext->differentialAssocTypeWitnessStructKey) + { + return InstPair(primal, primal); + } + else + { + // We can't really differentiate a call to a IDifferentiable method here. + // They need to be specialized first. + return InstPair(primal, nullptr); + } + } + else if (auto returnWitnessType = as<IRWitnessTableTypeBase>(lookupInst->getDataType())) + { + // T.Diff_Is_IDifferential ==> T.Diff_Is_IDifferential.Diff_Is_IDifferential + if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiableInterfaceType) + { + auto primalDiffType = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + primal, + autoDiffSharedContext->differentialAssocTypeStructKey); + auto diffWitness = builder->emitLookupInterfaceMethodInst( + (IRType*)primalDiffType, + primal, + autoDiffSharedContext->differentialAssocTypeWitnessStructKey); + return InstPair(primal, diffWitness); + } + } auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); @@ -563,6 +612,8 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui // IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* primalType) { + primalType = (IRType*)unwrapAttributedType(primalType); + if (auto diffType = differentiateType(builder, primalType)) { switch (diffType->getOp()) @@ -593,17 +644,18 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I // Since primalType 
has a corresponding differential type, we can lookup the // definition for zero(). - auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - if (!zeroMethod) + IRInst* zeroMethod = nullptr; + if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType)) { // if the differential type itself comes from a witness lookup, we can just lookup the // zero method from the same witness table. - if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType)) - { - auto wt = lookupInterface->getWitnessTable(); - zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); - builder->markInstAsDifferential(zeroMethod); - } + auto wt = lookupInterface->getWitnessTable(); + zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); + builder->markInstAsDifferential(zeroMethod); + } + else + { + zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); } SLANG_RELEASE_ASSERT(zeroMethod); @@ -747,6 +799,8 @@ static void _markGenericChildrenWithoutRelaventUse(IRGeneric* origGeneric, HashS case kIROp_UserDefinedBackwardDerivativeDecoration: case kIROp_ForwardDerivativeDecoration: case kIROp_BackwardDerivativeDecoration: + case kIROp_BackwardDerivativeIntermediateTypeDecoration: + case kIROp_BackwardDerivativePrimalContextDecoration: case kIROp_BackwardDerivativePrimalDecoration: case kIROp_BackwardDerivativePropagateDecoration: break; diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 2a341ed38..92bd0b0a8 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -786,6 +786,12 @@ struct DiffTransposePass for (auto externInst : externInsts) { + if (isNoDiffType(externInst->getDataType())) + { + popRevGradients(externInst); 
+ continue; + } + auto primalType = tryGetPrimalTypeFromDiffInst(externInst); SLANG_ASSERT(primalType); @@ -960,11 +966,24 @@ struct DiffTransposePass return as<IRDifferentialPairType>(type); }; + struct DiffValWriteBack + { + IRInst* destVar; + IRInst* srcTempPairVar; + }; + List<DiffValWriteBack> writebacks; + + auto baseFnType = as<IRFuncType>(baseFn->getDataType()); + + SLANG_RELEASE_ASSERT(baseFnType); + SLANG_RELEASE_ASSERT(fwdCall->getArgCount() == baseFnType->getParamCount()); + for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++) { auto arg = fwdCall->getArg(ii); - - if (arg->getOp() == kIROp_LoadReverseGradient) + auto paramType = baseFnType->getParamType(ii); + + if (as<IRLoadReverseGradient>(arg)) { // Original parameters that are `out DifferentiableType` will turn into // a `in Differential` parameter. The split logic will insert LoadReverseGradient insts @@ -972,6 +991,20 @@ struct DiffTransposePass // and use it as the final argument. args.add(builder->emitLoad(arg->getOperand(0))); } + else if (auto instPair = as<IRReverseGradientDiffPairRef>(arg)) + { + // An argument to an inout parameter will come in the form of a ReverseGradientDiffPairRef(primalVar, diffVar) inst + // after splitting. + // In order to perform the call, we need a temporary var to store the DiffPair. 
+ auto pairType = as<IRPtrTypeBase>(arg->getDataType())->getValueType(); + auto tempVar = builder->emitVar(pairType); + auto primalVal = builder->emitLoad(instPair->getPrimal()); + auto diffVal = builder->emitLoad(instPair->getDiff()); + auto pairVal = builder->emitMakeDifferentialPair(pairType, primalVal, diffVal); + builder->emitStore(tempVar, pairVal); + args.add(tempVar); + writebacks.add(DiffValWriteBack{instPair->getDiff(), tempVar}); + } else if (!as<IRPtrTypeBase>(arg->getDataType()) && getDiffPairType(arg->getDataType())) { // Normal differentiable input parameter will become an inout DiffPair parameter @@ -1002,9 +1035,17 @@ struct DiffTransposePass } else { - args.add(arg); - argTypes.add(arg->getDataType()); - argRequiresLoad.add(false); + if (as<IROutType>(paramType)) + { + args.add(nullptr); + argRequiresLoad.add(false); + } + else + { + args.add(arg); + argTypes.add(arg->getDataType()); + argRequiresLoad.add(false); + } } } @@ -1015,11 +1056,10 @@ struct DiffTransposePass argRequiresLoad.add(false); } - args.add(primalContextDecor->getBackwardDerivativePrimalContextVar()); - argTypes.add(builder->getOutType( - as<IRPtrTypeBase>( + args.add(builder->emitLoad(primalContextDecor->getBackwardDerivativePrimalContextVar())); + argTypes.add(as<IRPtrTypeBase>( primalContextDecor->getBackwardDerivativePrimalContextVar()->getDataType()) - ->getValueType())); + ->getValueType()); argRequiresLoad.add(false); auto revFnType = builder->getFuncType(argTypes, builder->getVoidType()); @@ -1027,11 +1067,27 @@ struct DiffTransposePass revFnType, baseFn); - builder->emitCallInst(revFnType->getResultType(), revCallee, args); + List<IRInst*> callArgs; + for (auto arg : args) + if (arg) + callArgs.add(arg); + builder->emitCallInst(revFnType->getResultType(), revCallee, callArgs); + + // Write back the result gradients to their corresponding split variables. 
+ for (auto wb : writebacks) + { + auto loadedPair = builder->emitLoad(wb.srcTempPairVar); + auto diffType = as<IRPtrTypeBase>(wb.destVar->getDataType())->getValueType(); + auto loadedDiff = builder->emitDifferentialPairGetDifferential(diffType, loadedPair); + builder->emitStore(wb.destVar, loadedDiff); + } List<RevGradient> gradients; - for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++) + for (Index ii = 0; ii < args.getCount(); ii++) { + if (!args[ii]) + continue; + // Is this arg relevant to auto-diff? if (auto diffPairType = getDiffPairType(args[ii]->getDataType())) { @@ -1043,13 +1099,11 @@ struct DiffTransposePass auto diffArgType = (IRType*)diffTypeContext.getDifferentialForType( builder, diffPairType->getValueType()); - auto diffArgPtrType = builder->getPtrType(kIROp_PtrType, diffArgType); - gradients.add(RevGradient( RevGradient::Flavor::Simple, fwdCall->getArg(ii), builder->emitDifferentialPairGetDifferential( - diffArgPtrType, builder->emitLoad(args[ii])), + diffArgType, builder->emitLoad(args[ii])), nullptr)); } } @@ -1308,14 +1362,13 @@ struct DiffTransposePass TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*) { - IRInst* revVal = nullptr; - if (auto revGradDecor = fwdStore->getPtr()->findDecoration<IROutParamReverseGradientDecoration>()) + IRInst* revVal = builder->emitLoad(fwdStore->getPtr()); + if (auto diffPairType = as<IRDifferentialPairType>(revVal->getDataType())) { - revVal = revGradDecor->getValue(); - } - else - { - revVal = builder->emitLoad(fwdStore->getPtr()); + revVal = builder->emitDifferentialPairGetDifferential( + (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType( + builder, diffPairType), + revVal); } return TranspositionResult( List<RevGradient>( @@ -1747,7 +1800,8 @@ struct DiffTransposePass for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++) { auto operand = fwdInst->getOperand(ii); - if (operand->getDataType() != targetType) + auto operandType = 
unwrapAttributedType(operand->getDataType()); + if (operandType != targetType) { // Insert new operand just after the old operand, so we have the old // operands available. @@ -2271,7 +2325,6 @@ struct DiffTransposePass { gradientsMap[fwdInst] = List<RevGradient>(); } - gradientsMap[fwdInst].GetValue().add(assignment); } diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 378ea1cc2..be20d8aa8 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -1,6 +1,7 @@ #include "slang-ir-autodiff-unzip.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-util.h" +#include "slang-ir-autodiff-rev.h" namespace Slang { @@ -279,7 +280,7 @@ struct ExtractPrimalFuncContext inst); } - IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, HashSet<IRInst*>& primalParams, IRInst*& outIntermediateType) { IRBuilder builder(sharedBuilder); @@ -294,6 +295,7 @@ struct ExtractPrimalFuncContext auto oldIntermediateParam = func->getLastParam(); auto outIntermediary = builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + primalParams.Add(outIntermediary); oldIntermediateParam->replaceUsesWith(outIntermediary); oldIntermediateParam->removeAndDeallocate(); @@ -368,62 +370,18 @@ struct ExtractPrimalFuncContext auto defVal = builder.emitDefaultConstructRaw((IRType*)intermediateType); builder.emitStore(outIntermediary, defVal); - // The primal func will not have the result derivative param (second to last param), so we remove it. 
- if (isResultDifferentiable) - { - auto resultDerivativeParam = func->getLastParam()->getPrevParam(); - SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses()); - resultDerivativeParam->removeAndDeallocate(); - } - - // Finally, go through parameters and translate their type back to primal type. + // Remove any parameters not in `primalParams` set. + List<IRInst*> params; for (auto param = func->getFirstParam(); param;) { - auto next = param->getNextParam(); - [this, firstBlock, &builder, param]() + auto nextParam = param->getNextParam(); + if (!primalParams.Contains(param)) { - for (auto use = param->firstUse; use; use = use->nextUse) - { - if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration) - { - use->getUser()->getParent()->replaceUsesWith(param); - return; - } - else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration) - { - // This is a propagate func specific parameter, we should remove it. - SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse()); - param->removeAndDeallocate(); - return; - } - } - - IRInst* valueType = param->getDataType(); - auto inoutType = as<IRPtrTypeBase>(param->getDataType()); - if (inoutType) valueType = inoutType->getValueType(); - auto diffPairType = as<IRDifferentialPairType>(valueType); - if (!diffPairType) - return; - - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - - auto originalValueType = diffPairType->getValueType(); - - // Create a local var to act as the old param. - auto tempVar = builder.emitVar(diffPairType); - param->replaceUsesWith(tempVar); - auto pairValue = builder.emitMakeDifferentialPair( - diffPairType, - param, - backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType)); - builder.emitStore(tempVar, pairValue); - - // Change the param type to original type. 
- param->setFullType(originalValueType); - }(); - param = next; + param->replaceUsesWith(builder.getVoidValue()); + param->removeAndDeallocate(); + } + param = nextParam; } - return unzippedFunc; } }; @@ -446,7 +404,10 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE } IRFunc* DiffUnzipPass::extractPrimalFunc( - IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType) + IRFunc* func, + IRFunc* originalFunc, + ParameterBlockTransposeInfo& paramInfo, + IRInst*& intermediateType) { IRBuilder builder(this->autodiffContext->sharedBuilder); builder.setInsertBefore(func); @@ -456,11 +417,31 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( subEnv.parent = &cloneEnv; auto clonedFunc = as<IRFunc>(cloneInst(&subEnv, &builder, func)); + // Remove [KeepAlive] decorations in clonedFunc. + for (auto block : clonedFunc->getBlocks()) + for (auto inst : block->getChildren()) + if (auto decor = inst->findDecoration<IRKeepAliveDecoration>()) + decor->removeAndDeallocate(); + + // Remove propagate func specific primal insts from cloned func. 
+ for (auto inst : paramInfo.propagateFuncSpecificPrimalInsts) + { + auto newInst = subEnv.mapOldValToNew[inst].GetValue(); + newInst->removeAndDeallocate(); + } + + HashSet<IRInst*> newPrimalParams; + for (auto param : func->getParams()) + { + if (paramInfo.primalFuncParams.Contains(param)) + newPrimalParams.Add(subEnv.mapOldValToNew[param].GetValue()); + } + ExtractPrimalFuncContext context; context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber); intermediateType = nullptr; - auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType); + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, newPrimalParams, intermediateType); if (auto nameHint = primalFunc->findDecoration<IRNameHintDecoration>()) { @@ -489,21 +470,26 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( if (auto structKeyDecor = inst->findDecoration<IRPrimalValueStructKeyDecoration>()) { builder.setInsertBefore(inst); - auto val = builder.emitFieldExtract( - inst->getDataType(), - intermediateVar, - structKeyDecor->getStructKey()); if (inst->getOp() == kIROp_Var) { // This is a var for intermediate context. + auto valType = cast<IRPtrTypeBase>(inst->getFullType())->getValueType(); + auto val = builder.emitFieldExtract( + valType, + intermediateVar, + structKeyDecor->getStructKey()); auto tempVar = - builder.emitVar(cast<IRPtrTypeBase>(inst->getFullType())->getValueType()); + builder.emitVar(valType); builder.emitStore(tempVar, val); inst->replaceUsesWith(tempVar); } else { // Orindary value. + auto val = builder.emitFieldExtract( + inst->getFullType(), + intermediateVar, + structKeyDecor->getStructKey()); inst->replaceUsesWith(val); } instsToRemove.add(inst); @@ -522,6 +508,8 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( { inst->removeAndDeallocate(); } + + stripTempDecorations(func); // Run simplification to DCE unnecessary insts. 
eliminateDeadCode(func); diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 3055d057b..057ff53c4 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -14,6 +14,8 @@ namespace Slang { +struct ParameterBlockTransposeInfo; + struct DiffUnzipPass { AutoDiffSharedContext* autodiffContext; @@ -120,7 +122,7 @@ struct DiffUnzipPass return diffMap[inst]; } - IRFunc* unzipDiffInsts(IRFunc* func) + void unzipDiffInsts(IRFunc* func) { diffTypeContext.setFunc(func); @@ -129,21 +131,33 @@ struct DiffUnzipPass IRBuilder* builder = &builderStorage; - // Clone the entire function. - // TODO: Maybe don't clone? The reverse-mode process seems to clone several times. - // TODO: Looks like we get a copy of the decorations? - IRCloneEnv subEnv; - subEnv.parent = &cloneEnv; - builder->setInsertBefore(func); - IRFunc* unzippedFunc = as<IRFunc>(cloneInst(&subEnv, builder, func)); + IRFunc* unzippedFunc = func; - builder->setInsertInto(unzippedFunc); + // Initialize the primal/diff map for parameters. + // Generate distinct references for parameters that should be split. + // We don't actually modify the parameter list here, instead we emit + // PrimalParamRef(param) and DiffParamRef(param) and use those to represent + // a use from the primal or diff part of the program. 
+ builder->setInsertBefore(unzippedFunc->getFirstBlock()->getTerminator()); - auto originalParam = func->getFirstParam(); for (auto primalParam = unzippedFunc->getFirstParam(); primalParam; primalParam = primalParam->getNextParam()) { - primalMap[originalParam] = primalParam; - originalParam = originalParam->getNextParam(); + auto type = primalParam->getFullType(); + if (auto ptrType = as<IRPtrTypeBase>(type)) + { + type = ptrType->getValueType(); + } + if (auto pairType = as<IRDifferentialPairType>(type)) + { + IRInst* diffType = diffTypeContext.getDifferentialTypeFromDiffPairType(builder, pairType); + if (as<IRPtrTypeBase>(primalParam->getFullType())) + diffType = builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType); + auto primalRef = builder->emitPrimalParamRef(primalParam); + auto diffRef = builder->emitDiffParamRef((IRType*)diffType, primalParam); + builder->markInstAsDifferential(diffRef, pairType->getValueType()); + primalMap[primalParam] = primalRef; + diffMap[primalParam] = diffRef; + } } // Functions need to have at least two blocks at this point (one for parameters, @@ -239,8 +253,6 @@ struct DiffUnzipPass // Remove old blocks. 
for (auto block : mixedBlocks) block->removeAndDeallocate(); - - return unzippedFunc; } IRBlock* getInitializerBlock(IndexedRegion* region) @@ -476,25 +488,12 @@ struct DiffUnzipPass } } - IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType); - - bool isRelevantDifferentialPair(IRType* type) - { - if (as<IRDifferentialPairType>(type)) - { - return true; - } - else if (auto argPtrType = as<IRPtrTypeBase>(type)) - { - if (as<IRDifferentialPairType>(argPtrType->getValueType())) - { - return true; - } - } - - return false; - } - + IRFunc* extractPrimalFunc( + IRFunc* func, + IRFunc* originalFunc, + ParameterBlockTransposeInfo& paramInfo, + IRInst*& intermediateType); + static IRInst* _getOriginalFunc(IRInst* call) { if (auto decor = call->findDecoration<IRAutoDiffOriginalValueDecoration>()) @@ -606,7 +605,13 @@ struct DiffUnzipPass } else if (auto inoutType = as<IRInOutType>(primalParamType)) { - SLANG_UNIMPLEMENTED_X("nested call inout parameter"); + // Since arg is split into separate vars, we need a new temp var that represents + // the remerged diff pair. 
+ auto diffPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(arg->getDataType())->getValueType()); + auto primalValueType = diffPairType->getValueType(); + auto diffPairRef = diffBuilder->emitReverseGradientDiffPairRef(arg->getDataType(), primalArg, diffArg); + diffBuilder->markInstAsDifferential(diffPairRef, primalValueType); + diffArgs.add(diffPairRef); } else { @@ -661,20 +666,6 @@ struct DiffUnzipPass InstPair splitLoad(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoad* mixedLoad) { - if (auto param = as<IRParam>(mixedLoad->getPtr())) - { - auto diffPairPtrType = as<IRPtrTypeBase>(param->getFullType()); - SLANG_RELEASE_ASSERT(diffPairPtrType); - auto diffPairType = as<IRDifferentialPairType>(diffPairPtrType->getValueType()); - SLANG_RELEASE_ASSERT(diffPairType); - auto diffType = (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType(diffBuilder, diffPairType); - auto loadedParam = primalBuilder->emitLoad(param); - return InstPair( - primalBuilder->emitDifferentialPairGetPrimal(loadedParam), - primalBuilder->emitDifferentialPairGetDifferential(diffType, loadedParam)); - } - - // Everything else should have already been split. auto primalPtr = lookupPrimalInst(mixedLoad->getPtr()); auto diffPtr = lookupDiffInst(mixedLoad->getPtr()); auto primalVal = primalBuilder->emitLoad(primalPtr); @@ -685,22 +676,14 @@ struct DiffUnzipPass InstPair splitStore(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRStore* mixedStore) { - // We will only generate mixed store to parameters. 
- if (!as<IRParam>(mixedStore->getPtr())) - { - SLANG_UNIMPLEMENTED_X("Splitting a store that is not writing to a param."); - } - - auto primalAddr = mixedStore->getPtr(); + auto primalAddr = lookupPrimalInst(mixedStore->getPtr()); + auto diffAddr = lookupDiffInst(mixedStore->getPtr()); auto primalVal = lookupPrimalInst(mixedStore->getVal()); auto diffVal = lookupDiffInst(mixedStore->getVal()); - // For now the param type and value type will not type-check in these store insts, - // but the param inst will be changed to the correct type after we synthesize primal and - // propagate func. auto primalStore = primalBuilder->emitStore(primalAddr, primalVal); - auto diffStore = diffBuilder->emitStore(primalAddr, diffVal); + auto diffStore = diffBuilder->emitStore(diffAddr, diffVal); diffBuilder->markInstAsDifferential(diffStore, primalVal->getFullType()); return InstPair(primalStore, diffStore); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 7a2e8c75e..fdaff4960 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -496,6 +496,8 @@ void stripTempDecorations(IRInst* inst) case kIROp_DifferentialInstDecoration: case kIROp_MixedDifferentialInstDecoration: case kIROp_AutoDiffOriginalValueDecoration: + case kIROp_BackwardDerivativePrimalReturnDecoration: + case kIROp_PrimalValueStructKeyDecoration: decor->removeAndDeallocate(); break; default: diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 30f053673..0662cf846 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -286,6 +286,7 @@ struct DifferentialPairTypeBuilder }; void stripAutoDiffDecorations(IRModule* module); +void stripTempDecorations(IRInst* inst); bool isNoDiffType(IRType* paramType); @@ -309,4 +310,20 @@ bool isBackwardDifferentiableFunc(IRInst* func); bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* typeInst); +inline bool 
isRelevantDifferentialPair(IRType* type) +{ + if (as<IRDifferentialPairType>(type)) + { + return true; + } + else if (auto argPtrType = as<IRPtrTypeBase>(type)) + { + if (as<IRDifferentialPairType>(argPtrType->getValueType())) + { + return true; + } + } + return false; +} + }; diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index 5b5ace64b..dbeb1e934 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -279,6 +279,7 @@ IRInst* cloneInst( } void cloneDecoration( + IRCloneEnv* cloneEnv, IRDecoration* oldDecoration, IRInst* newParent, IRModule* module) @@ -292,6 +293,7 @@ void cloneDecoration( builder.setInsertInto(newParent); IRCloneEnv env; + env.parent = cloneEnv; cloneInst(&env, &builder, oldDecoration); } @@ -300,6 +302,7 @@ void cloneDecoration( IRInst* newParent) { cloneDecoration( + nullptr, oldDecoration, newParent, newParent->getModule()); diff --git a/source/slang/slang-ir-clone.h b/source/slang/slang-ir-clone.h index 824806d57..f4f53ff92 100644 --- a/source/slang/slang-ir-clone.h +++ b/source/slang/slang-ir-clone.h @@ -133,6 +133,7 @@ IRInst* cloneInst( /// Uses `module` to allocate any new instructions. /// void cloneDecoration( + IRCloneEnv* parentEnv, IRDecoration* oldDecoration, IRInst* newParent, IRModule* module); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index e1143b7b9..1cb839751 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -330,6 +330,18 @@ INST(Store, store, 2, 0) // currently accumulated derivative to pass to some dOut argument in a nested call. INST(LoadReverseGradient, LoadReverseGradient, 1, 0) +// Produced and removed during backward auto-diff pass as a temporary placeholder containing the +// primal and accumulated derivative values to pass to an inout argument in a nested call. 
+INST(ReverseGradientDiffPairRef, ReverseGradientDiffPairRef, 2, 0) + +// Produced and removed during backward auto-diff pass. This inst is generated by the splitting step +// to represent a reference to an inout parameter for use in the primal part of the computation. +INST(PrimalParamRef, PrimalParamRef, 1, 0) + +// Produced and removed during backward auto-diff pass. This inst is generated by the splitting step +// to represent a reference to an inout parameter for use in the back-prop part of the computation. +INST(DiffParamRef, DiffParamRef, 1, 0) + INST(FieldExtract, get_field, 2, 0) INST(FieldAddress, get_field_addr, 2, 0) @@ -771,12 +783,6 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// forward-differentiated updateElement inst. INST(PrimalElementTypeDecoration, primalElementType, 1, 0) - /// Used by the auto-diff pass. An `out T` parameter will transcribe to a `in T.Differential` parameter. - /// We will also create a temp var of type `T.Differential` in the function body so the `load` and `stores` - /// can operand on a valid address. We use this decoration to associate this temp var with its corresponding - /// input parameter. - INST(OutParamReverseGradientDecoration, outParamRevGrad, 1, 0) - /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. 
INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 132a96f16..d1374477f 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -723,18 +723,6 @@ struct IRMixedDifferentialInstDecoration : IRDecoration IRType* getPairType() { return as<IRType>(getOperand(0)); } }; -struct IROutParamReverseGradientDecoration : IRDecoration -{ - enum - { - kOp = kIROp_OutParamReverseGradientDecoration - }; - - IR_LEAF_ISA(OutParamReverseGradientDecoration) - - IRInst* getValue() { return getOperand(0); } -}; - struct IRBackwardDifferentiableDecoration : IRDecoration { enum @@ -1782,12 +1770,31 @@ struct IRGetElementPtr : IRInst IRInst* getIndex() { return getOperand(1); } }; -struct IRLoadReverseGradient :IRInst +struct IRLoadReverseGradient : IRInst { IR_LEAF_ISA(LoadReverseGradient) IRInst* getValue() { return getOperand(0); } }; +struct IRReverseGradientDiffPairRef : IRInst +{ + IR_LEAF_ISA(ReverseGradientDiffPairRef) + IRInst* getPrimal() { return getOperand(0); } + IRInst* getDiff() { return getOperand(1); } +}; + +struct IRPrimalParamRef : IRInst +{ + IR_LEAF_ISA(PrimalParamRef) + IRInst* getReferencedParam() { return getOperand(0); } +}; + +struct IRDiffParamRef : IRInst +{ + IR_LEAF_ISA(DiffParamRef) + IRInst* getReferencedParam() { return getOperand(0); } +}; + struct IRGetNativePtr : IRInst { IR_LEAF_ISA(GetNativePtr); @@ -3145,6 +3152,9 @@ public: IRInst* ptr); IRInst* emitLoadReverseGradient(IRType* type, IRInst* diffValue); + IRInst* emitReverseGradientDiffPairRef(IRType* type, IRInst* primalVar, IRInst* diffVar); + IRInst* emitPrimalParamRef(IRInst* param); + IRInst* emitDiffParamRef(IRType* type, IRInst* param); IRInst* emitStore( IRInst* dstPtr, diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 747d0ccdd..cf7acd46c 100644 --- a/source/slang/slang-ir-specialize.cpp +++ 
b/source/slang/slang-ir-specialize.cpp @@ -392,13 +392,14 @@ struct SpecializationContext { for (auto decor : genericReturnVal->getDecorations()) { + bool specialized = false; if (decor->getOp() == kIROp_ForwardDerivativeDecoration || decor->getOp() == kIROp_UserDefinedBackwardDerivativeDecoration) { // If we already have a diff func on this specialize, skip. if (auto specDiffRef = specInst->findDecorationImpl(decor->getOp())) { - return false; + continue; } auto specDiffFunc = as<IRSpecialize>(decor->getOperand(0)); @@ -443,8 +444,10 @@ struct SpecializationContext builder.addDecoration(specInst, decor->getOp(), newDiffFunc); - return true; + specialized = true; } + if (specialized) + return true; } } return false; diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index ee55a6546..3f250e31e 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -386,7 +386,7 @@ static void cloneRelevantDecorations( // if( !val->findDecorationImpl(decoration->getOp()) ) { - cloneDecoration(decoration, val, var->getModule()); + cloneDecoration(nullptr, decoration, val, var->getModule()); } break; } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 4814726cf..558574bf6 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4234,6 +4234,50 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitReverseGradientDiffPairRef(IRType* type, IRInst* primalVar, IRInst* diffVar) + { + auto inst = createInst<IRReverseGradientDiffPairRef>( + this, + kIROp_ReverseGradientDiffPairRef, + type, + primalVar, + diffVar); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitPrimalParamRef(IRInst* param) + { + auto type = param->getFullType(); + auto ptrType = as<IRPtrTypeBase>(type); + auto valueType = type; + if (ptrType) valueType = ptrType->getValueType(); + auto pairType = as<IRDifferentialPairType>(valueType); + IRType* finalType = pairType->getValueType(); + if (ptrType) finalType = 
getPtrType(ptrType->getOp(), finalType); + auto inst = createInst<IRPrimalParamRef>( + this, + kIROp_PrimalParamRef, + finalType, + param); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitDiffParamRef(IRType* type, IRInst* param) + { + auto inst = createInst<IRDiffParamRef>( + this, + kIROp_DiffParamRef, + type, + param); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitLoad( IRType* type, IRInst* ptr) @@ -6753,6 +6797,8 @@ namespace Slang // common subexpression elimination, etc. // auto call = cast<IRCall>(this); + if (call->findDecoration<IRNoSideEffectDecoration>()) + return false; return !isPureFunctionalCall(call); } break; @@ -6809,10 +6855,14 @@ namespace Slang case kIROp_MakeOptionalNone: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: + case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferential: + case kIROp_MakeDifferentialPair: case kIROp_MakeTuple: case kIROp_GetTupleElement: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_LoadReverseGradient: + case kIROp_ReverseGradientDiffPairRef: case kIROp_ImageSubscript: case kIROp_FieldExtract: case kIROp_FieldAddress: |
