From bbd1e1786401bb88c34802b987d4da72e2364503 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 1 Feb 2023 14:18:57 -0800 Subject: Support `out` parameters in backward differentiation. (#2619) * Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He --- source/slang/slang-ir-addr-inst-elimination.cpp | 31 +- source/slang/slang-ir-autodiff-fwd.cpp | 215 +++++++- source/slang/slang-ir-autodiff-fwd.h | 6 + source/slang/slang-ir-autodiff-rev.cpp | 557 +++++++++------------ source/slang/slang-ir-autodiff-rev.h | 18 +- .../slang/slang-ir-autodiff-transcriber-base.cpp | 65 +-- source/slang/slang-ir-autodiff-transcriber-base.h | 2 + source/slang/slang-ir-autodiff-transpose.h | 75 ++- source/slang/slang-ir-autodiff-unzip.cpp | 82 +-- source/slang/slang-ir-autodiff-unzip.h | 176 +++++-- source/slang/slang-ir-autodiff.cpp | 10 +- source/slang/slang-ir-autodiff.h | 2 + source/slang/slang-ir-check-differentiability.cpp | 30 +- source/slang/slang-ir-init-local-var.cpp | 34 ++ source/slang/slang-ir-init-local-var.h | 14 + source/slang/slang-ir-inst-defs.h | 10 + source/slang/slang-ir-insts.h | 23 +- source/slang/slang-ir-util.cpp | 4 +- source/slang/slang-ir.cpp | 35 +- 19 files changed, 835 insertions(+), 554 deletions(-) create mode 100644 source/slang/slang-ir-init-local-var.cpp create mode 100644 source/slang/slang-ir-init-local-var.h (limited to 'source/slang') diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index a5e0e0a4e..a451e24a5 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -54,11 +54,18 @@ struct AddressInstEliminationContext } endLoop:; auto lastAddr = accessChain.getLast(); - auto lastVal = builder.emitLoad(lastAddr); accessChain.removeLast(); accessChain.reverse(); - auto update = builder.emitUpdateElement(lastVal, accessChain, val); - builder.emitStore(lastAddr, update); + if (accessChain.getCount()) + { + auto lastVal = builder.emitLoad(lastAddr); + auto update = builder.emitUpdateElement(lastVal, accessChain, val); + builder.emitStore(lastAddr, update); + } + else + { + builder.emitStore(lastAddr, val); + } } void transformLoadAddr(IRUse* use) @@ -92,7 +99,22 @@ struct AddressInstEliminationContext IRBuilder builder(sharedBuilder); builder.setInsertBefore(call); auto tempVar = builder.emitVar(cast(addr->getFullType())->getValueType()); - builder.emitStore(tempVar, getValue(builder, addr)); + auto callee = getResolvedInstForDecorations(call->getCallee()); + auto funcType = as(callee->getFullType()); + SLANG_RELEASE_ASSERT(funcType); + UInt paramIndex = (UInt)(use - call->getOperands() - 1); + SLANG_RELEASE_ASSERT(call->getArg(paramIndex) == addr); + if (!as(funcType->getParamType(paramIndex))) + { + builder.emitStore(tempVar, getValue(builder, addr)); + } + else + { + builder.emitStore( + tempVar, + builder.emitDefaultConstruct( + as(tempVar->getDataType())->getValueType())); + } builder.setInsertAfter(call); storeValue(builder, addr, builder.emitLoad(tempVar)); use->set(tempVar); @@ -170,4 +192,5 @@ SlangResult eliminateAddressInsts( AddressInstEliminationContext ctx; return ctx.eliminateAddressInstsImpl(sharedBuilder, policy, func, sink); } + } // namespace Slang diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index f60412efb..a9e716ce4 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -7,6 +7,7 @@ #include "slang-ir-eliminate-phis.h" #include "slang-ir-util.h" #include "slang-ir-inst-pass-base.h" +#include "slang-ir-single-return.h" namespace Slang { @@ -232,6 +233,8 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or builder->markInstAsMixedDifferential(diffStoreVal, diffPairType); auto store = builder->emitStore(primalStoreLocation, valToStore); + builder->markInstAsMixedDifferential(store, diffPairType); + return InstPair(store, nullptr); } } @@ -385,12 +388,18 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig SLANG_ASSERT(calleeType); SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount()); + auto placeholderCall = builder->emitCallInst(nullptr, builder->emitUndefined(builder->getTypeKind()), 0, nullptr); + builder->setInsertBefore(placeholderCall); + IRBuilder argBuilder = *builder; + IRBuilder afterBuilder = argBuilder; + afterBuilder.setInsertAfter(placeholderCall); + List args; // Go over the parameter list and create pairs for each input (if required) for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) { auto origArg = origCall->getArg(ii); - auto primalArg = findOrTranscribePrimalInst(builder, origArg); + auto primalArg = findOrTranscribePrimalInst(&argBuilder, origArg); SLANG_ASSERT(primalArg); auto primalType = primalArg->getDataType(); @@ -402,20 +411,71 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig while (auto attrType = as(primalType)) primalType = attrType->getBaseType(); } - if (auto pairType = tryGetDiffPairType(builder, primalType)) + if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) { - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); + auto pairPtrType = as(pairType); + auto pairValType = as( + pairPtrType ? pairPtrType->getValueType() : pairType); + auto diffType = differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(&argBuilder, pairValType); + if (auto ptrParamType = as(paramType)) + { + // Create temp var to pass in/out arguments. + auto srcVar = argBuilder.emitVar(ptrParamType->getValueType()); + argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType()); + + auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); + if (ptrParamType->getOp() == kIROp_InOutType) + { + // Set initial value. + auto primalVal = argBuilder.emitLoad(primalArg); + auto diffArgVal = diffArg; + if (!diffArg) + diffArgVal = getDifferentialZeroOfType(builder, (IRType*)pairValType->getValueType()); + else + { + diffArgVal = argBuilder.emitLoad(diffArg); + argBuilder.markInstAsDifferential(diffArgVal, pairValType->getValueType()); + } + auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal); + argBuilder.markInstAsMixedDifferential(initVal, primalType); + auto store = argBuilder.emitStore(srcVar, initVal); + argBuilder.markInstAsMixedDifferential(store, primalType); + } + if (as(ptrParamType)) + { + // Read back new value. + auto newVal = afterBuilder.emitLoad(srcVar); + afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType()); + auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(newVal); + afterBuilder.emitStore(primalArg, newPrimalVal); + + if (diffArg) + { + auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal); + afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType()); + auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal); + afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType()); + } + } + args.add(srcVar); + continue; + } + else + { + auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); + if (!diffArg) + diffArg = getDifferentialZeroOfType(&argBuilder, primalType); - // If a pair type can be formed, this must be non-null. - SLANG_RELEASE_ASSERT(diffArg); - - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); - builder->markInstAsMixedDifferential(diffPair, pairType); + // If a pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffArg); - args.add(diffPair); - continue; + auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg); + argBuilder.markInstAsMixedDifferential(diffPair, pairType); + + args.add(diffPair); + continue; + } + } } // Argument is not differentiable. @@ -424,26 +484,29 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } IRType* diffReturnType = nullptr; - diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); + diffReturnType = tryGetDiffPairType(&argBuilder, origCall->getFullType()); if (!diffReturnType) { SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); - diffReturnType = builder->getVoidType(); + diffReturnType = argBuilder.getVoidType(); } - auto callInst = builder->emitCallInst( + auto callInst = argBuilder.emitCallInst( diffReturnType, diffCallee, args); - builder->markInstAsMixedDifferential(callInst, diffReturnType); - builder->addAutoDiffOriginalValueDecoration(callInst, primalCallee); + placeholderCall->removeAndDeallocate(); + argBuilder.markInstAsMixedDifferential(callInst, diffReturnType); + argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee); + + *builder = afterBuilder; if (diffReturnType->getOp() != kIROp_VoidType) { - IRInst* primalResultValue = builder->emitDifferentialPairGetPrimal(callInst); - auto diffType = differentiateType(builder, origCall->getFullType()); - IRInst* diffResultValue = builder->emitDifferentialPairGetDifferential(diffType, callInst); + IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst); + auto diffType = differentiateType(&afterBuilder, origCall->getFullType()); + IRInst* diffResultValue = afterBuilder.emitDifferentialPairGetDifferential(diffType, callInst); return InstPair(primalResultValue, diffResultValue); } else @@ -1150,6 +1213,8 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr builder.setInsertInto(diffFunc); differentiableTypeConformanceContext.setFunc(primalFunc); + + mapInOutParamToWriteBackValue.Clear(); // Transcribe children from origFunc into diffFunc for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) @@ -1160,6 +1225,43 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr for (auto block = primalFunc->getFirstBlock(); block; block = block->getNextBlock()) as(lookupDiffInst(block))->insertAtEnd(diffFunc); + for (auto block : diffFunc->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + // Insert write backs to mutable parameters before returning. + builder.setInsertBefore(inst); + for (auto& writeBack : mapInOutParamToWriteBackValue) + { + auto param = writeBack.Key; + auto primalVal = builder.emitLoad(writeBack.Value.primal); + IRInst* valToStore = nullptr; + if (writeBack.Value.differential) + { + auto diffVal = builder.emitLoad(writeBack.Value.differential); + builder.markInstAsDifferential(diffVal, primalVal->getFullType()); + valToStore = builder.emitMakeDifferentialPair(cast(param->getFullType())->getValueType(), + primalVal, diffVal); + builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType()); + } + else + { + valToStore = builder.emitLoad(writeBack.Value.primal); + } + + auto storeInst = builder.emitStore(param, valToStore); + + if (writeBack.Value.differential) + { + builder.markInstAsMixedDifferential(storeInst, valToStore->getFullType()); + } + } + } + } + } + return InstPair(primalFunc, diffFunc); } @@ -1297,4 +1399,77 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* return InstPair(nullptr, nullptr); } +String ForwardDiffTranscriber::makeDiffPairName(IRInst* origVar) +{ + if (auto namehintDecoration = origVar->findDecoration()) + { + return ("dp" + String(namehintDecoration->getName())); + } + + return String(""); +} + +InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) +{ + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); + + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + + SLANG_ASSERT(diffPairParam); + + if (auto pairType = as(diffPairType)) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + else if (auto pairPtrType = as(diffPairType)) + { + auto ptrInnerPairType = as(pairPtrType->getValueType()); + // Make a local copy of the parameter for primal and diff parts. + auto primal = builder->emitVar(ptrInnerPairType->getValueType()); + auto diffType = differentiateType(builder, cast(origParam->getDataType())->getValueType()); + auto diff = builder->emitVar(diffType); + + IRInst* primalInitVal = nullptr; + IRInst* diffInitVal = nullptr; + if (as(diffPairType)) + { + primalInitVal = builder->emitDefaultConstruct(ptrInnerPairType->getValueType()); + diffInitVal = builder->emitDefaultConstructRaw(diffType); + } + else + { + auto initVal = builder->emitLoad(diffPairParam); + primalInitVal = builder->emitDifferentialPairGetPrimal(initVal); + diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal); + } + builder->markInstAsDifferential(diffInitVal, ptrInnerPairType->getValueType()); + + builder->emitStore(primal, primalInitVal); + + auto diffStore = builder->emitStore(diff, diffInitVal); + builder->markInstAsDifferential(diffStore, ptrInnerPairType->getValueType()); + + mapInOutParamToWriteBackValue[diffPairParam] = InstPair(primal, diff); + return InstPair(primal, diff); + } + } + + auto primalInst = cloneInst(&cloneEnv, builder, origParam); + if (auto primalParam = as(primalInst)) + { + SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); + primalParam->removeFromParent(); + builder->getInsertLoc().getBlock()->addParam(primalParam); + } + return InstPair(primalInst, nullptr); +} + } diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index e595191a3..260b0a433 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -8,11 +8,15 @@ namespace Slang struct ForwardDiffTranscriber : AutoDiffTranscriberBase { + // Pending values to write back to inout params at the end of the current function. + OrderedDictionary mapInOutParamToWriteBackValue; + ForwardDiffTranscriber(AutoDiffSharedContext* shared, SharedIRBuilder* inSharedBuilder, DiagnosticSink* inSink) : AutoDiffTranscriberBase(shared, inSharedBuilder, inSink) { } + // Returns "d" to use as a name hint for variables and parameters. // If no primal name is available, returns a blank string. // @@ -95,6 +99,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) override; + virtual InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) override; + virtual IROp getInterfaceRequirementDerivativeDecorationOp() override { return kIROp_ForwardDerivativeDecoration; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 0f2ceceb4..9c63a4012 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -11,6 +11,7 @@ #include "slang-ir-single-return.h" #include "slang-ir-addr-inst-elimination.h" #include "slang-ir-eliminate-multilevel-break.h" +#include "slang-ir-init-local-var.h" namespace Slang { @@ -21,32 +22,10 @@ namespace Slang for (UIndex i = 0; i < funcType->getParamCount(); i++) { - bool noDiff = false; auto origType = funcType->getParamType(i); - auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType); - - if (auto attrType = as(primalType)) - { - if (attrType->findAttr()) - { - noDiff = true; - primalType = attrType->getBaseType(); - } - } - if (noDiff) - { - newParameterTypes.add(primalType); - } - else - { - if (auto diffPairType = tryGetDiffPairType(builder, origType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - newParameterTypes.add(inoutDiffPairType); - } - else - newParameterTypes.add(primalType); - } + auto paramType = transcribeParamTypeForPropagateFunc(builder, origType); + if (paramType) + newParameterTypes.add(paramType); } if (auto diffResultType = differentiateType(builder, funcType->getResultType())) @@ -75,7 +54,7 @@ namespace Slang for (UInt i = 0; i < funcType->getParamCount(); i++) { auto origType = funcType->getParamType(i); - auto primalType = (IRType*)findOrTranscribePrimalInst(builder, origType); + auto primalType = transcribeParamTypeForPrimalFunc(builder, origType); paramTypes.add(primalType); } paramTypes.add(outType); @@ -252,52 +231,57 @@ namespace Slang return String(""); } - InstPair BackwardDiffTranscriberBase::transposeBlock(IRBuilder* builder, IRBlock* origBlock) + static IRType* _getPrimalTypeFromNoDiffType(BackwardDiffTranscriberBase* transcriber, IRBuilder* builder, IRType* origType) { - IRBuilder subBuilder(builder->getSharedBuilder()); - subBuilder.setInsertLoc(builder->getInsertLoc()); + IRType* valueType = origType; + auto ptrType = as(valueType); + if (ptrType) + valueType = ptrType->getValueType(); - IRBlock* diffBlock = subBuilder.emitBlock(); - - subBuilder.setInsertInto(diffBlock); - - // First transcribe every parameter in the block. - for (auto param = origBlock->getFirstParam(); param; param = param->getNextParam()) - this->copyParam(&subBuilder, param); - - // The extra param for input gradient - auto gradParam = subBuilder.emitParam(as(origBlock->getParent()->getFullType())->getResultType()); - - // Then, run through every instruction and use the transcriber to generate the appropriate - // derivative code. - // - for (auto child = origBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) - this->copyInst(&subBuilder, child); - - auto lastInst = diffBlock->getLastOrdinaryInst(); - List grads = { gradParam }; - upperGradients.Add(lastInst, grads); - for (auto child = diffBlock->getLastOrdinaryInst(); child; child = child->getPrevInst()) + if (auto attrType = as(valueType)) { - auto upperGrads = upperGradients.TryGetValue(child); - if (!upperGrads) - continue; - if (upperGrads->getCount() > 1) + if (attrType->findAttr()) { - auto sumGrad = upperGrads->getFirst(); - for (auto i = 1; i < upperGrads->getCount(); i++) - { - sumGrad = subBuilder.emitAdd(sumGrad->getDataType(), sumGrad, (*upperGrads)[i]); - } - this->transposeInstBackward(&subBuilder, child, sumGrad); + auto primalValueType = (IRType*)transcriber->findOrTranscribePrimalInst(builder, valueType); + if (ptrType) + return builder->getPtrType(ptrType->getOp(), primalValueType); + return primalValueType; } - else - this->transposeInstBackward(&subBuilder, child, upperGrads->getFirst()); } + return nullptr; + } + + IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPrimalFunc(IRBuilder* builder, IRType* paramType) + { + // If the param is marked as no_diff, return the primal type. + if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) + return primalNoDiffType; - subBuilder.emitReturn(); + return (IRType*)findOrTranscribePrimalInst(builder, paramType); + } - return InstPair(diffBlock, diffBlock); + IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType) + { + if (auto outType = as(paramType)) + { + auto valueType = outType->getValueType(); + auto diffValueType = differentiateType(builder, valueType); + return diffValueType; + } + + // If the param is marked as no_diff, return the primal type. + if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) + return primalNoDiffType; + + auto diffPairType = tryGetDiffPairType(builder, paramType); + if (diffPairType) + { + if (!as(diffPairType)) + return builder->getInOutType(diffPairType); + return diffPairType; + } + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType); + return primalType; } // Create an empty func to represent the transcribed func of `origFunc`. @@ -387,39 +371,65 @@ namespace Slang IRBuilder builder(inBuilder->getSharedBuilder()); builder.setInsertInto(header.differential); builder.emitBlock(); - auto funcType = as(header.differential->getDataType()); + auto origFuncType = as(origFunc->getFullType()); List primalArgs, propagateArgs; List primalTypes, propagateTypes; - for (UInt i = 0; i < funcType->getParamCount(); i++) + for (UInt i = 0; i < origFuncType->getParamCount(); i++) { - auto paramType = (IRType*)findOrTranscribePrimalInst(&builder, funcType->getParamType(i)); - auto param = builder.emitParam(paramType); - if (i != funcType->getParamCount() - 1) + auto primalParamType = transcribeParamTypeForPrimalFunc(&builder, origFuncType->getParamType(i)); + auto propagateParamType = transcribeParamTypeForPropagateFunc(&builder, origFuncType->getParamType(i)); + if (propagateParamType) { - primalArgs.add(param); - } - propagateArgs.add(param); - propagateTypes.add(paramType); - } + auto param = builder.emitParam(propagateParamType); + propagateTypes.add(propagateParamType); + propagateArgs.add(param); - // Fetch primal values to use as arguments in primal func call. - for (auto& arg : primalArgs) - { - IRInst* valueType = arg->getDataType(); - auto inoutType = as(arg->getDataType()); - if (inoutType) + // Fetch primal values to use as arguments in primal func call. + IRInst* primalArg = param; + if (!as(primalParamType)) + { + // As long as the primal parameter is not an out type, + // we need to fetch the primal value from the parameter. + if (as(propagateParamType)) + { + primalArg = builder.emitLoad(param); + } + if (auto diffPairType = as(primalArg->getDataType())) + { + primalArg = builder.emitDifferentialPairGetPrimal(primalArg); + } + } + if (auto primalParamPtrType = as(primalParamType)) + { + // If primal parameter is mutable, we need to pass in a temp var. + auto tempVar = builder.emitVar(primalParamPtrType->getValueType()); + if (primalParamPtrType->getOp() == kIROp_InOutType) + { + // If the primal parameter is inout, we need to set the initial value. + builder.emitStore(tempVar, primalArg); + } + primalArgs.add(tempVar); + } + else + { + primalArgs.add(primalArg); + } + } + else { - valueType = inoutType->getValueType(); - arg = builder.emitLoad(arg); + auto var = builder.emitVar(primalParamType); + primalArgs.add(var); } - auto diffPairType = as(valueType); - if (!diffPairType) continue; - arg = builder.emitDifferentialPairGetPrimal(arg); + primalTypes.add(primalParamType); } - for (auto& arg : primalArgs) + // Add dOut argument to propagateArgs. + auto diffResultType = differentiateType(&builder, origFunc->getResultType()); + if (diffResultType) { - primalTypes.add(arg->getFullType()); + auto param = builder.emitParam(diffResultType); + propagateArgs.add(param); + propagateTypes.add(param->getFullType()); } auto outerGeneric = findOuterGeneric(origFunc); @@ -433,7 +443,6 @@ namespace Slang auto intermediateVar = builder.emitVar(intermediateType); - auto origFuncType = as(origFunc->getDataType()); auto primalFuncType = builder.getFuncType( primalTypes, origFuncType->getResultType()); @@ -486,6 +495,51 @@ namespace Slang builder.emitBranch(firstBlock); } + void insertTempVarForMutableParams(SharedIRBuilder* sharedBuilder, IRFunc* func) + { + IRBuilder builder(sharedBuilder); + auto firstBlock = func->getFirstBlock(); + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + + OrderedDictionary mapParamToTempVar; + List params; + for (auto param : firstBlock->getParams()) + { + if (auto ptrType = as(param->getDataType())) + { + params.add(param); + } + } + + for (auto param : params) + { + auto ptrType = as(param->getDataType()); + auto tempVar = builder.emitVar(ptrType->getValueType()); + mapParamToTempVar[param] = tempVar; + if (param->getOp() != kIROp_OutType) + { + builder.emitStore(tempVar, builder.emitLoad(param)); + } + param->replaceUsesWith(tempVar); + } + + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Return) + { + builder.setInsertBefore(inst); + for (auto& kv : mapParamToTempVar) + { + builder.emitStore(kv.Key, builder.emitLoad(kv.Value)); + } + } + } + } + } + + struct AutoDiffAddressConversionPolicy : public AddressConversionPolicy { DifferentiableTypeConformanceContext* diffTypeContext; @@ -512,6 +566,8 @@ namespace Slang IRCFGNormalizationPass cfgPass = {this->getSink()}; normalizeCFG(autoDiffSharedContext->sharedBuilder, func); + insertTempVarForMutableParams(sharedBuilder, func); + AutoDiffAddressConversionPolicy cvtPolicty; cvtPolicty.diffTypeContext = &diffTypeContext; auto result = eliminateAddressInsts(sharedBuilder, &cvtPolicty, func, sink); @@ -592,6 +648,23 @@ namespace Slang return fwdDiffFunc; } + InstPair BackwardDiffTranscriberBase::transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) + { + SLANG_UNUSED(primalType); + + SLANG_RELEASE_ASSERT(origParam->getParent() && origParam->getParent()->getParent() + && origParam->getParent()->getParent()->getOp() == kIROp_Generic); + + auto primalInst = maybeCloneForPrimalInst(builder, origParam); + if (auto primalParam = as(primalInst)) + { + SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); + primalParam->removeFromParent(); + builder->getInsertLoc().getBlock()->addParam(primalParam); + } + return InstPair(primalInst, nullptr); + } + // Transcribe a function definition. void BackwardDiffTranscriberBase::transcribeFuncImpl(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffPropagateFunc) { @@ -615,6 +688,8 @@ namespace Slang if (!fwdDiffFunc) return; + bool isResultDifferentiable = as(fwdDiffFunc->getResultType()); + // Split first block into a paramter block. this->makeParameterBlock(&tempBuilder, as(fwdDiffFunc)); @@ -642,12 +717,11 @@ namespace Slang } // Transpose the first block (parameter block) - transposeParameterBlock(builder, diffPropagateFunc); + List primalFuncSpecificParams; + auto dOutParameter = transposeParameterBlock(builder, diffPropagateFunc, primalFuncSpecificParams, isResultDifferentiable); builder->setInsertInto(diffPropagateFunc); - auto dOutParameter = diffPropagateFunc->getLastParam()->getPrevParam(); - // Transpose differential blocks from unzippedFwdDiffFunc into diffFunc (with dOutParameter) representing the DiffTransposePass::FuncTranspositionInfo info = {dOutParameter, nullptr}; diffTransposePass->transposeDiffBlocksInFunc(diffPropagateFunc, info); @@ -658,11 +732,32 @@ namespace Slang // with the intermediate results computed from the extracted func. IRInst* intermediateType = nullptr; auto extractedPrimalFunc = diffUnzipPass->extractPrimalFunc( - diffPropagateFunc, primalFunc, intermediateType); + diffPropagateFunc, primalFunc, isResultDifferentiable, intermediateType); // Clean up by deallocating the tempoarary forward derivative func. fwdDiffFunc->removeAndDeallocate(); + // Remove primalFuncSpecificParams. + for (auto specificParam : primalFuncSpecificParams) + { + while (auto use = specificParam->firstUse) + { + if (use->getUser()->getOp() == kIROp_Store && use == use->getUser()->getOperands()) + { + use->getUser()->removeAndDeallocate(); + } + else if (auto decor = as(use->getUser())) + { + decor->removeAndDeallocate(); + } + else + { + SLANG_UNEXPECTED("unexpected use of transcribed param."); + } + } + specificParam->removeAndDeallocate(); + } + // If primal function is nested in a generic, we want to create separate generics for all the associated things // we have just created. auto primalOuterGeneric = findOuterGeneric(primalFunc); @@ -689,9 +784,16 @@ namespace Slang auto specializedBackwardPrimalFunc = maybeSpecializeWithGeneric(*builder, primalFuncGeneric, primalOuterGeneric); builder->addBackwardDerivativePrimalDecoration(primalFunc, specializedBackwardPrimalFunc); } + + initializeLocalVariables(builder->getSharedBuilder(), primalFunc); + initializeLocalVariables(builder->getSharedBuilder(), diffPropagateFunc); } - void BackwardDiffTranscriberBase::transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc) + IRInst* BackwardDiffTranscriberBase::transposeParameterBlock( + IRBuilder* builder, + IRFunc* diffFunc, + List& primalFuncSpecificParams, + bool isResultDifferentiable) { IRBlock* fwdDiffParameterBlock = diffFunc->getFirstBlock(); @@ -699,7 +801,7 @@ namespace Slang auto fwdParamBlockBranch = as(fwdDiffParameterBlock->getTerminator()); auto nextBlock = fwdParamBlockBranch->getTargetBlock(); - builder->setInsertInto(fwdDiffParameterBlock); + builder->setInsertBefore(fwdParamBlockBranch); List fwdParams; for (auto child = fwdDiffParameterBlock->getFirstParam(); child; child = child->getNextParam()) @@ -710,8 +812,37 @@ namespace Slang // 1. Turn fwd-diff versions of the parameters into reverse-diff versions by wrapping them as InOutType<> for (auto fwdParam : fwdParams) { - // TODO: Handle ptr types. - if (auto diffPairType = as(fwdParam->getDataType())) + if (auto outType = as(fwdParam->getDataType())) + { + IRParam* newPropParam = nullptr; + IRParam* newPrimalParam = nullptr; + auto diffPairType = as(outType->getValueType()); + if (diffPairType) + { + // Create dOut param. + auto diffType = (IRType*)differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(builder, diffPairType); + newPropParam = builder->emitParam(diffType); + newPrimalParam = builder->emitParam(builder->getOutType(diffPairType->getValueType())); + } + else + { + newPrimalParam = builder->emitParam(outType); + } + + // Create a temp var to represent the original `out` param. + auto arg = builder->emitVar(outType->getValueType()); + builder->addAutoDiffOriginalValueDecoration(arg, newPrimalParam); + if (newPropParam) + { + builder->addDecoration(arg, kIROp_OutParamReverseGradientDecoration, newPropParam); + } + + fwdParam->replaceUsesWith(arg); + fwdParam->removeAndDeallocate(); + + primalFuncSpecificParams.add(newPrimalParam); + } + else if (auto diffPairType = as(fwdParam->getDataType())) { // Create inout version. auto inoutDiffPairType = builder->getInOutType(diffPairType); @@ -725,7 +856,7 @@ namespace Slang } else { - // Default case (parameter has nothing to do with differentiation) + // Default case (parameter is inout type or has nothing to do with differentiation) // Simply move the parameter to the end. // fwdParam->removeFromParent(); @@ -735,236 +866,24 @@ namespace Slang auto paramCount = as(diffFunc->getDataType())->getParamCount(); - // 2. Add a parameter for 'derivative of the output' (d_out). + // 2. If the return type of the original function is differentiable, + // add a parameter for 'derivative of the output' (d_out). // The type is the second last parameter type of the function. // - auto dOutParamType = as(diffFunc->getDataType())->getParamType(paramCount - 2); - - SLANG_ASSERT(dOutParamType); - - builder->emitParam(dOutParamType); - - // Add a parameter for intermediate val. - builder->emitParam(as(diffFunc->getDataType())->getParamType(paramCount - 1)); - } - - IRInst* BackwardDiffTranscriberBase::copyParam(IRBuilder* builder, IRParam* origParam) - { - auto primalDataType = origParam->getDataType(); - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - auto inoutDiffPairType = builder->getPtrType(kIROp_InOutType, diffPairType); - IRInst* diffParam = builder->emitParam(inoutDiffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffParam); - auto paramValue = builder->emitLoad(diffParam); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - orginalToTranscribed.Add(origParam, primal); - primalToDiffPair.Add(primal, diffParam); - - return diffParam; - } - - return maybeCloneForPrimalInst(builder, origParam); - } - - InstPair BackwardDiffTranscriberBase::copyBinaryArith(IRBuilder* builder, IRInst* origArith) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto origLeft = origArith->getOperand(0); - auto origRight = origArith->getOperand(1); - - IRInst* primalLeft; - if (!orginalToTranscribed.TryGetValue(origLeft, primalLeft)) - { - primalLeft = origLeft; - } - IRInst* primalRight; - if (!orginalToTranscribed.TryGetValue(origRight, primalRight)) - { - primalRight = origRight; - } - - auto resultType = origArith->getDataType(); - IRInst* newInst; - switch (origArith->getOp()) - { - case kIROp_Add: - newInst = builder->emitAdd(resultType, primalLeft, primalRight); - break; - case kIROp_Mul: - newInst = builder->emitMul(resultType, primalLeft, primalRight); - break; - case kIROp_Sub: - newInst = builder->emitSub(resultType, primalLeft, primalRight); - break; - case kIROp_Div: - newInst = builder->emitDiv(resultType, primalLeft, primalRight); - break; - default: - newInst = nullptr; - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - orginalToTranscribed.Add(origArith, newInst); - return InstPair(newInst, nullptr); - } - - IRInst* BackwardDiffTranscriberBase::transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad) - { - SLANG_ASSERT(origArith->getOperandCount() == 2); - - auto lhs = origArith->getOperand(0); - auto rhs = origArith->getOperand(1); - - if (as(lhs->getDataType())) - { - lhs = builder->emitLoad(lhs); - lhs = builder->emitDifferentialPairGetPrimal(lhs); - } - if (as(rhs->getDataType())) - { - rhs = builder->emitLoad(rhs); - rhs = builder->emitDifferentialPairGetPrimal(rhs); - } - - IRInst* leftGrad; - IRInst* rightGrad; - - - switch (origArith->getOp()) - { - case kIROp_Add: - leftGrad = grad; - rightGrad = grad; - break; - case kIROp_Mul: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); - break; - case kIROp_Sub: - leftGrad = grad; - rightGrad = builder->emitNeg(grad->getDataType(), grad); - break; - case kIROp_Div: - leftGrad = builder->emitMul(grad->getDataType(), rhs, grad); - rightGrad = builder->emitMul(grad->getDataType(), lhs, grad); // TODO 1.0 / Grad - break; - default: - getSink()->diagnose(origArith->sourceLoc, - Diagnostics::unimplemented, - "this arithmetic instruction cannot be differentiated"); - } - - lhs = origArith->getOperand(0); - rhs = origArith->getOperand(1); - if (auto leftGrads = upperGradients.TryGetValue(lhs)) - { - leftGrads->add(leftGrad); - } - else + IRParam* dOutParam = nullptr; + if (isResultDifferentiable) { - upperGradients.Add(lhs, leftGrad); - } - if (auto rightGrads = upperGradients.TryGetValue(rhs)) - { - rightGrads->add(rightGrad); - } - else - { - upperGradients.Add(rhs, rightGrad); - } + auto dOutParamType = as(diffFunc->getDataType())->getParamType(paramCount - 2); - return nullptr; - } + SLANG_ASSERT(dOutParamType); - InstPair BackwardDiffTranscriberBase::copyInst(IRBuilder* builder, IRInst* origInst) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transcribeParam(builder, as(origInst)); - - case kIROp_Return: - return InstPair(nullptr, nullptr); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return copyBinaryArith(builder, origInst); - - default: - // Not yet implemented - SLANG_ASSERT(0); + dOutParam = builder->emitParam(dOutParamType); } - return InstPair(nullptr, nullptr); - } - - IRInst* BackwardDiffTranscriberBase::transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad) - { - IRInOutType* inoutParam = as(param->getDataType()); - auto pairType = as(inoutParam->getValueType()); - auto paramValue = builder->emitLoad(param); - auto primal = builder->emitDifferentialPairGetPrimal(paramValue); - auto diff = builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - paramValue - ); - auto newDiff = builder->emitAdd(grad->getDataType(), diff, grad); - auto updatedParam = builder->emitMakeDifferentialPair(pairType, primal, newDiff); - auto store = builder->emitStore(param, updatedParam); - - return store; - } - - IRInst* BackwardDiffTranscriberBase::transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad) - { - // Handle common SSA-style operations - switch (origInst->getOp()) - { - case kIROp_Param: - return transposeParamBackward(builder, as(origInst), grad); - - case kIROp_Add: - case kIROp_Mul: - case kIROp_Sub: - case kIROp_Div: - return transposeBinaryArithBackward(builder, origInst, grad); - - case kIROp_DifferentialPairGetPrimal: - { - if (auto param = primalToDiffPair.TryGetValue(origInst)) - { - if (auto leftGrads = upperGradients.TryGetValue(*param)) - { - leftGrads->add(grad); - } - else - { - upperGradients.Add(*param, grad); - } - } - else - SLANG_ASSERT(0); - return nullptr; - } - - default: - // Not yet implemented - SLANG_ASSERT(0); - } + // Add a parameter for intermediate val. + builder->emitParam(as(diffFunc->getDataType())->getParamType(paramCount - 1)); - return nullptr; + return dOutParam; } InstPair BackwardDiffTranscriberBase::transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) diff --git a/source/slang/slang-ir-autodiff-rev.h b/source/slang/slang-ir-autodiff-rev.h index f789089b0..617e6b79b 100644 --- a/source/slang/slang-ir-autodiff-rev.h +++ b/source/slang/slang-ir-autodiff-rev.h @@ -61,7 +61,8 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase IRFuncType* differentiateFunctionTypeImpl(IRBuilder* builder, IRFuncType* funcType, IRInst* intermediateType); - InstPair transposeBlock(IRBuilder* builder, IRBlock* origBlock); + IRType* transcribeParamTypeForPrimalFunc(IRBuilder* builder, IRType* paramType); + IRType* transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType); // Puts parameters into their own block. void makeParameterBlock(IRBuilder* inBuilder, IRFunc* func); @@ -69,19 +70,10 @@ struct BackwardDiffTranscriberBase : AutoDiffTranscriberBase // Transcribe a function definition. virtual InstPair transcribeFunc(IRBuilder* builder, IRFunc* primalFunc, IRFunc* diffFunc) = 0; - void transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc); + // Transcribes the parameter block and returns the dOut param if exists. + IRInst* transposeParameterBlock(IRBuilder* builder, IRFunc* diffFunc, List& primalFuncSpecificParams, bool isResultDifferentiable); - IRInst* copyParam(IRBuilder* builder, IRParam* origParam); - - InstPair copyBinaryArith(IRBuilder* builder, IRInst* origArith); - - IRInst* transposeBinaryArithBackward(IRBuilder* builder, IRInst* origArith, IRInst* grad); - - InstPair copyInst(IRBuilder* builder, IRInst* origInst); - - IRInst* transposeParamBackward(IRBuilder* builder, IRInst* param, IRInst* grad); - - IRInst* transposeInstBackward(IRBuilder* builder, IRInst* origInst, IRInst* grad); + InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType); InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 520c6d276..8f21e8c62 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -17,16 +17,6 @@ DiagnosticSink* AutoDiffTranscriberBase::getSink() return sink; } -String AutoDiffTranscriberBase::makeDiffPairName(IRInst* origVar) -{ - if (auto namehintDecoration = origVar->findDecoration()) - { - return ("dp" + String(namehintDecoration->getName())); - } - - return String(""); -} - void AutoDiffTranscriberBase::mapDifferentialInst(IRInst* origInst, IRInst* diffInst) { if (hasDifferentialInst(origInst)) @@ -523,46 +513,7 @@ InstPair AutoDiffTranscriberBase::transcribeParam(IRBuilder* builder, IRParam* o bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); if (isFuncParam) { - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) - { - IRInst* diffPairParam = builder->emitParam(diffPairType); - - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - - SLANG_ASSERT(diffPairParam); - - if (auto pairType = as(diffPairType)) - { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); - } - else if (auto pairPtrType = as(diffPairType)) - { - auto ptrInnerPairType = as(pairPtrType->getValueType()); - - return InstPair( - builder->emitDifferentialPairAddressPrimal(diffPairParam), - builder->emitDifferentialPairAddressDifferential( - builder->getPtrType( - kIROp_PtrType, - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, ptrInnerPairType)), - diffPairParam)); - } - } - - auto primalInst = cloneInst(&cloneEnv, builder, origParam); - if (auto primalParam = as(primalInst)) - { - SLANG_RELEASE_ASSERT(builder->getInsertLoc().getBlock()); - primalParam->removeFromParent(); - builder->getInsertLoc().getBlock()->addParam(primalParam); - } - return InstPair(primalInst, nullptr); + return transcribeFuncParam(builder, origParam, primalDataType); } else { @@ -617,10 +568,14 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I switch (diffType->getOp()) { case kIROp_DifferentialPairType: - return builder->emitMakeDifferentialPair( - diffType, - getDifferentialZeroOfType(builder, as(diffType)->getValueType()), - getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + { + auto makeDiffPair = builder->emitMakeDifferentialPair( + diffType, + getDifferentialZeroOfType(builder, as(diffType)->getValueType()), + getDifferentialZeroOfType(builder, as(diffType)->getValueType())); + builder->markInstAsDifferential(makeDiffPair, as(diffType)->getValueType()); + return makeDiffPair; + } } if (auto arrayType = as(primalType)) @@ -647,6 +602,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I { auto wt = lookupInterface->getWitnessTable(); zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); + builder->markInstAsDifferential(zeroMethod); } } SLANG_RELEASE_ASSERT(zeroMethod); @@ -759,6 +715,7 @@ InstPair AutoDiffTranscriberBase::transcribeReturn(IRBuilder* builder, IRReturn* IRInst* primalReturnVal = findOrTranscribePrimalInst(builder, origReturnVal); IRInst* primalReturn = builder->emitReturn(primalReturnVal); + builder->markInstAsMixedDifferential(primalReturn, nullptr); return InstPair(primalReturn, nullptr); } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index e6a525dee..208bfbc28 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -123,6 +123,8 @@ struct AutoDiffTranscriberBase InstPair transcribeParam(IRBuilder* builder, IRParam* origParam); + virtual InstPair transcribeFuncParam(IRBuilder* builder, IRParam* origParam, IRInst* primalType) = 0; + InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst); InstPair transcribeBlockImpl(IRBuilder* builder, IRBlock* origBlock, HashSet& instsToSkip); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 5aad6e3a3..2a341ed38 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -501,6 +501,10 @@ struct DiffTransposePass List workList; // Build initial list of blocks to process by checking if they're differential blocks. + List traverseWorkList; + HashSet traverseSet; + traverseWorkList.add(revDiffFunc->getFirstBlock()); + traverseSet.Add(revDiffFunc->getFirstBlock()); for (IRBlock* block = revDiffFunc->getFirstBlock(); block; block = block->getNextBlock()) { if (!isDifferentialInst(block)) @@ -534,10 +538,13 @@ struct DiffTransposePass for (auto block : workList) { // Set dOutParameter as the transpose gradient for the return inst, if any. - if (auto returnInst = as(block->getTerminator())) + if (transposeInfo.dOutInst) { - this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr)); - retVal = returnInst->getVal(); + if (auto returnInst = as(block->getTerminator())) + { + this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr)); + retVal = returnInst->getVal(); + } } IRBlock* revBlock = revBlockMap[block]; @@ -575,7 +582,7 @@ struct DiffTransposePass auto branch = subBuilder.emitBranch(firstRevBlock); - if (!retVal) + if (!retVal || retVal->getOp() == kIROp_VoidLit) { retVal = subBuilder.getVoidValue(); } @@ -849,6 +856,8 @@ struct DiffTransposePass { auto returnPairType = as( tryGetPrimalTypeFromDiffInst(returnInst->getVal())); + if (!returnPairType) + return; primalType = returnPairType->getValueType(); } else if (auto loadInst = as(inst)) @@ -955,21 +964,33 @@ struct DiffTransposePass { auto arg = fwdCall->getArg(ii); - // If this isn't a ptr-type, make a var. - if (!as(arg->getDataType()) && getDiffPairType(arg->getDataType())) + if (arg->getOp() == kIROp_LoadReverseGradient) { - auto pairType = as(arg->getDataType()); + // Original parameters that are `out DifferentiableType` will turn into + // a `in Differential` parameter. The split logic will insert LoadReverseGradient insts + // to inform us this case. Here we just need to generate a load of the derivative variable + // and use it as the final argument. + args.add(builder->emitLoad(arg->getOperand(0))); + } + else if (!as(arg->getDataType()) && getDiffPairType(arg->getDataType())) + { + // Normal differentiable input parameter will become an inout DiffPair parameter + // in the propagate func. The split logic has already prepared the initial value + // to pass in. We need to define a temp variable with this initial value and pass + // in the temp variable as argument to the inout parameter. - auto var = builder->emitVar(arg->getDataType()); + auto makePairArg = as(arg); + SLANG_RELEASE_ASSERT(makePairArg); - SLANG_ASSERT(as(arg)); + auto pairType = as(arg->getDataType()); + auto var = builder->emitVar(arg->getDataType()); // Initialize this var to (arg.primal, 0). builder->emitStore( - var, + var, builder->emitMakeDifferentialPair( arg->getDataType(), - as(arg)->getPrimalValue(), + makePairArg->getPrimalValue(), builder->emitCallInst( (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()), diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), @@ -987,9 +1008,12 @@ struct DiffTransposePass } } - args.add(revValue); - argTypes.add(revValue->getDataType()); - argRequiresLoad.add(false); + if (revValue) + { + args.add(revValue); + argTypes.add(revValue->getDataType()); + argRequiresLoad.add(false); + } args.add(primalContextDecor->getBackwardDerivativePrimalContextVar()); argTypes.add(builder->getOutType( @@ -1024,10 +1048,8 @@ struct DiffTransposePass gradients.add(RevGradient( RevGradient::Flavor::Simple, fwdCall->getArg(ii), - builder->emitLoad( - builder->emitDifferentialPairAddressDifferential( - diffArgPtrType, - args[ii])), + builder->emitDifferentialPairGetDifferential( + diffArgPtrType, builder->emitLoad(args[ii])), nullptr)); } } @@ -1213,6 +1235,8 @@ struct DiffTransposePass case kIROp_UpdateElement: return transposeUpdateElement(builder, fwdInst, revValue); + case kIROp_LoadReverseGradient: + case kIROp_DefaultConstruct: case kIROp_Specialize: case kIROp_unconditionalBranch: case kIROp_conditionalBranch: @@ -1266,8 +1290,8 @@ struct DiffTransposePass if (as(loadType)) { - auto primalPtr = builder->emitDifferentialPairAddressPrimal(revPtr); - auto primalVal = builder->emitLoad(primalPtr); + auto primalPairVal = builder->emitLoad(revPtr); + auto primalVal = builder->emitDifferentialPairGetPrimal(primalPairVal); auto pairVal = builder->emitMakeDifferentialPair(loadType, primalVal, aggregateGradient); @@ -1284,12 +1308,21 @@ struct DiffTransposePass TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*) { + IRInst* revVal = nullptr; + if (auto revGradDecor = fwdStore->getPtr()->findDecoration()) + { + revVal = revGradDecor->getValue(); + } + else + { + revVal = builder->emitLoad(fwdStore->getPtr()); + } return TranspositionResult( List( RevGradient( RevGradient::Flavor::Simple, fwdStore->getVal(), - builder->emitLoad(fwdStore->getPtr()), + revVal, fwdStore))); } diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index daf6e44d4..378ea1cc2 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -279,7 +279,7 @@ struct ExtractPrimalFuncContext inst); } - IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, IRInst*& outIntermediateType) + IRFunc* turnUnzippedFuncIntoPrimalFunc(IRFunc* unzippedFunc, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& outIntermediateType) { IRBuilder builder(sharedBuilder); @@ -369,33 +369,59 @@ struct ExtractPrimalFuncContext builder.emitStore(outIntermediary, defVal); // The primal func will not have the result derivative param (second to last param), so we remove it. - auto resultDerivativeParam = func->getLastParam()->getPrevParam(); - SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses()); - resultDerivativeParam->removeAndDeallocate(); + if (isResultDifferentiable) + { + auto resultDerivativeParam = func->getLastParam()->getPrevParam(); + SLANG_RELEASE_ASSERT(!resultDerivativeParam->hasUses()); + resultDerivativeParam->removeAndDeallocate(); + } - // Finally, go through parameters and turn DifferentiablePair back to T. - for (auto param : func->getParams()) + // Finally, go through parameters and translate their type back to primal type. + for (auto param = func->getFirstParam(); param;) { - IRInst* valueType = param->getDataType(); - auto inoutType = as(param->getDataType()); - if (inoutType) valueType = inoutType->getValueType(); - auto diffPairType = as(valueType); - if (!diffPairType) continue; - builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); - - auto originalValueType = diffPairType->getValueType(); - - // Create a local var to act as the old param. - auto tempVar = builder.emitVar(diffPairType); - param->replaceUsesWith(tempVar); - auto pairValue = builder.emitMakeDifferentialPair( - diffPairType, - param, - backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType)); - builder.emitStore(tempVar, pairValue); - - // Change the param type to original type. - param->setFullType(originalValueType); + auto next = param->getNextParam(); + [this, firstBlock, &builder, param]() + { + for (auto use = param->firstUse; use; use = use->nextUse) + { + if (use->getUser()->getOp() == kIROp_AutoDiffOriginalValueDecoration) + { + use->getUser()->getParent()->replaceUsesWith(param); + return; + } + else if (use->getUser()->getOp() == kIROp_OutParamReverseGradientDecoration) + { + // This is a propagate func specific parameter, we should remove it. + SLANG_RELEASE_ASSERT(!param->hasMoreThanOneUse()); + param->removeAndDeallocate(); + return; + } + } + + IRInst* valueType = param->getDataType(); + auto inoutType = as(param->getDataType()); + if (inoutType) valueType = inoutType->getValueType(); + auto diffPairType = as(valueType); + if (!diffPairType) + return; + + builder.setInsertBefore(firstBlock->getFirstOrdinaryInst()); + + auto originalValueType = diffPairType->getValueType(); + + // Create a local var to act as the old param. + auto tempVar = builder.emitVar(diffPairType); + param->replaceUsesWith(tempVar); + auto pairValue = builder.emitMakeDifferentialPair( + diffPairType, + param, + backwardPrimalTranscriber->getDifferentialZeroOfType(&builder, originalValueType)); + builder.emitStore(tempVar, pairValue); + + // Change the param type to original type. + param->setFullType(originalValueType); + }(); + param = next; } return unzippedFunc; @@ -420,7 +446,7 @@ static void copyPrimalValueStructKeyDecorations(IRInst* inst, IRCloneEnv& cloneE } IRFunc* DiffUnzipPass::extractPrimalFunc( - IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType) + IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType) { IRBuilder builder(this->autodiffContext->sharedBuilder); builder.setInsertBefore(func); @@ -434,7 +460,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( context.init(autodiffContext->sharedBuilder, autodiffContext->transcriberSet.primalTranscriber); intermediateType = nullptr; - auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, intermediateType); + auto primalFunc = context.turnUnzippedFuncIntoPrimalFunc(clonedFunc, originalFunc, isResultDifferentiable, intermediateType); if (auto nameHint = primalFunc->findDecoration()) { diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index d808cbb5e..3055d057b 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -139,6 +139,13 @@ struct DiffUnzipPass builder->setInsertInto(unzippedFunc); + auto originalParam = func->getFirstParam(); + for (auto primalParam = unzippedFunc->getFirstParam(); primalParam; primalParam = primalParam->getNextParam()) + { + primalMap[originalParam] = primalParam; + originalParam = originalParam->getNextParam(); + } + // Functions need to have at least two blocks at this point (one for parameters, // and atleast one for code) // @@ -469,7 +476,7 @@ struct DiffUnzipPass } } - IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, IRInst*& intermediateType); + IRFunc* extractPrimalFunc(IRFunc* func, IRFunc* originalFunc, bool isResultDifferentiable, IRInst*& intermediateType); bool isRelevantDifferentialPair(IRType* type) { @@ -537,7 +544,6 @@ struct DiffUnzipPass for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { auto arg = mixedCall->getArg(ii); - if (isRelevantDifferentialPair(arg->getDataType())) { primalArgs.add(lookupPrimalInst(arg)); @@ -552,20 +558,29 @@ struct DiffUnzipPass auto mixedDecoration = mixedCall->findDecoration(); SLANG_ASSERT(mixedDecoration); - auto fwdPairResultType = as(mixedDecoration->getPairType()); - SLANG_ASSERT(fwdPairResultType); - - auto primalType = fwdPairResultType->getValueType(); - auto diffType = (IRType*) diffTypeContext.getDifferentialForType(&globalBuilder, primalType); + IRType* primalType = mixedCall->getFullType(); + IRType* diffType = mixedCall->getFullType(); + IRType* resultType = mixedCall->getFullType(); + if (auto fwdPairResultType = as(mixedDecoration->getPairType())) + { + primalType = fwdPairResultType->getValueType(); + diffType = (IRType*)diffTypeContext.getDifferentialForType(&globalBuilder, primalType); + resultType = fwdPairResultType; + } auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs); primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); + SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount()); + List diffArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { auto arg = mixedCall->getArg(ii); + // Depending on the type and direction of each argument, + // we might need to prepare a different value for the transposition logic to produce the + // correct final argument in the propagate function call. if (isRelevantDifferentialPair(arg->getDataType())) { auto primalArg = lookupPrimalInst(arg); @@ -574,18 +589,45 @@ struct DiffUnzipPass // If arg is a mixed differential (pair), it should have already been split. SLANG_ASSERT(primalArg); SLANG_ASSERT(diffArg); - - auto pairArg = diffBuilder->emitMakeDifferentialPair( + auto primalParamType = primalFuncType->getParamType(ii); + + if (auto outType = as(primalParamType)) + { + // For `out` parameters that expects an input derivative to propagate through, + // we insert a `LoadReverseGradient` inst here to signify the logic in `transposeStore` + // that this argument should actually be the currently accumulated derivative on + // this variable. The end purpose is that we will generate a load(diffArg) in the + // final transposed code and use that as the argument for the call, but we can't just + // emit a normal load inst here because the transposition logic will turn loads into stores. + auto outDiffType = cast(diffArg->getDataType())->getValueType(); + auto gradArg = diffBuilder->emitLoadReverseGradient(outDiffType, diffArg); + diffBuilder->markInstAsDifferential(gradArg, primalArg->getDataType()); + diffArgs.add(gradArg); + } + else if (auto inoutType = as(primalParamType)) + { + SLANG_UNIMPLEMENTED_X("nested call inout parameter"); + } + else + { + // For ordinary differentiable input parameters, we make sure to provide + // a differential pair. The actual logic that generates an inout variable + // will be handled in `transposeCall()`. + auto pairArg = diffBuilder->emitMakeDifferentialPair( arg->getDataType(), primalArg, diffArg); - diffBuilder->markInstAsDifferential(pairArg, primalArg->getDataType()); - diffArgs.add(pairArg); + diffBuilder->markInstAsDifferential(pairArg, primalArg->getDataType()); + diffArgs.add(pairArg); + } } else { - diffArgs.add(arg); + // For non differentiable arguments, we can simply pass the argument as is + // if this isn't a `out` parameter, in which case it is removed from propagate call. + if (!as(arg->getDataType())) + diffArgs.add(arg); } } @@ -593,19 +635,22 @@ struct DiffUnzipPass diffBuilder->markInstAsDifferential(newFwdCallee); - auto diffPairVal = diffBuilder->emitCallInst( - fwdPairResultType, + auto callInst = diffBuilder->emitCallInst( + resultType, newFwdCallee, diffArgs); - diffBuilder->markInstAsDifferential(diffPairVal, primalType); + diffBuilder->markInstAsDifferential(callInst, primalType); disableIRValidationAtInsert(); - diffBuilder->addBackwardDerivativePrimalContextDecoration(diffPairVal, intermediateVar); + diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar); enableIRValidationAtInsert(); - auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal); - diffBuilder->markInstAsDifferential(diffVal, primalType); - + IRInst* diffVal = nullptr; + if (as(callInst->getDataType())) + { + diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, callInst); + diffBuilder->markInstAsDifferential(diffVal, primalType); + } return InstPair(primalVal, diffVal); } @@ -616,52 +661,92 @@ struct DiffUnzipPass InstPair splitLoad(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoad* mixedLoad) { - // By the nature of how diff pairs are used, and the fact that FieldAddress/GetElementPtr, - // etc, cannot appear before a GetDifferential/GetPrimal, a mixed load can only be from a - // parameter or a variable. - // - if (as(mixedLoad->getPtr())) + if (auto param = as(mixedLoad->getPtr())) { - // Should not occur with current impl of fwd-mode. - // If impl. changes, impl this case too. - // - SLANG_UNIMPLEMENTED_X("Splitting a load from a param is not currently implemented."); + auto diffPairPtrType = as(param->getFullType()); + SLANG_RELEASE_ASSERT(diffPairPtrType); + auto diffPairType = as(diffPairPtrType->getValueType()); + SLANG_RELEASE_ASSERT(diffPairType); + auto diffType = (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType(diffBuilder, diffPairType); + auto loadedParam = primalBuilder->emitLoad(param); + return InstPair( + primalBuilder->emitDifferentialPairGetPrimal(loadedParam), + primalBuilder->emitDifferentialPairGetDifferential(diffType, loadedParam)); } // Everything else should have already been split. auto primalPtr = lookupPrimalInst(mixedLoad->getPtr()); auto diffPtr = lookupDiffInst(mixedLoad->getPtr()); + auto primalVal = primalBuilder->emitLoad(primalPtr); + auto diffVal = diffBuilder->emitLoad(diffPtr); + diffBuilder->markInstAsDifferential(diffVal, primalVal->getFullType()); + return InstPair(primalVal, diffVal); + } + + InstPair splitStore(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRStore* mixedStore) + { + // We will only generate mixed store to parameters. + if (!as(mixedStore->getPtr())) + { + SLANG_UNIMPLEMENTED_X("Splitting a store that is not writing to a param."); + } + + auto primalAddr = mixedStore->getPtr(); + + auto primalVal = lookupPrimalInst(mixedStore->getVal()); + auto diffVal = lookupDiffInst(mixedStore->getVal()); - return InstPair(primalBuilder->emitLoad(primalPtr), diffBuilder->emitLoad(diffPtr)); + // For now the param type and value type will not type-check in these store insts, + // but the param inst will be changed to the correct type after we synthesize primal and + // propagate func. + auto primalStore = primalBuilder->emitStore(primalAddr, primalVal); + auto diffStore = diffBuilder->emitStore(primalAddr, diffVal); + + diffBuilder->markInstAsDifferential(diffStore, primalVal->getFullType()); + return InstPair(primalStore, diffStore); } InstPair splitVar(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRVar* mixedVar) { - auto pairType = as(mixedVar->getDataType()); + auto pairType = as(as(mixedVar->getDataType())->getValueType()); auto primalType = pairType->getValueType(); auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType); - - return InstPair(primalBuilder->emitVar(primalType), diffBuilder->emitVar(diffType)); + auto primalVar = primalBuilder->emitVar(primalType); + auto diffVar = diffBuilder->emitVar(diffType); + diffBuilder->markInstAsDifferential(diffVar, primalType); + return InstPair(primalVar, diffVar); } InstPair splitReturn(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRReturn* mixedReturn) { auto pairType = as(mixedReturn->getVal()->getDataType()); - auto primalType = pairType->getValueType(); + // Are we returning a differentiable value? + if (pairType) + { + auto primalType = pairType->getValueType(); - // Check that we have an unambiguous 'first' differential block. - SLANG_ASSERT(firstDiffBlock); - auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); - auto pairVal = diffBuilder->emitMakeDifferentialPair( - pairType, - lookupPrimalInst(mixedReturn->getVal()), - lookupDiffInst(mixedReturn->getVal())); - diffBuilder->markInstAsDifferential(pairVal, primalType); + // Check that we have an unambiguous 'first' differential block. + SLANG_ASSERT(firstDiffBlock); + auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); + auto pairVal = diffBuilder->emitMakeDifferentialPair( + pairType, + lookupPrimalInst(mixedReturn->getVal()), + lookupDiffInst(mixedReturn->getVal())); + diffBuilder->markInstAsDifferential(pairVal, primalType); - auto returnInst = diffBuilder->emitReturn(pairVal); - diffBuilder->markInstAsDifferential(returnInst, primalType); + auto returnInst = diffBuilder->emitReturn(pairVal); + diffBuilder->markInstAsDifferential(returnInst, primalType); - return InstPair(primalBranch, returnInst); + return InstPair(primalBranch, returnInst); + } + else + { + // If return value is not differentiable, just turn it into a trivial branch. + auto primalBranch = primalBuilder->emitBranch(firstDiffBlock); + auto returnInst = diffBuilder->emitReturn(); + diffBuilder->markInstAsDifferential(returnInst, nullptr); + return InstPair(primalBranch, returnInst); + } } bool isBlockIndexed(IRBlock* block) @@ -973,6 +1058,9 @@ struct DiffUnzipPass case kIROp_Load: return splitLoad(primalBuilder, diffBuilder, as(inst)); + case kIROp_Store: + return splitStore(primalBuilder, diffBuilder, as(inst)); + case kIROp_Return: return splitReturn(primalBuilder, diffBuilder, as(inst)); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 8952f9756..7a2e8c75e 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -24,7 +24,7 @@ bool isBackwardDifferentiableFunc(IRInst* func) return false; } -static IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) { if (auto witnessTable = as(witness)) { @@ -400,6 +400,14 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b return nullptr; } +IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType( + IRBuilder* builder, IRDifferentialPairType* diffPairType) +{ + auto witness = diffPairType->getWitness(); + SLANG_RELEASE_ASSERT(witness); + return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); +} + void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { for (auto globalInst : sharedContext->moduleInst->getChildren()) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 2258ff753..30f053673 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -160,6 +160,8 @@ struct DifferentiableTypeConformanceContext IRInst* lookUpConformanceForType(IRInst* type); IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + + IRInst* getDifferentialTypeFromDiffPairType(IRBuilder* builder, IRDifferentialPairType* diffPairType); // Lookup and return the 'Differential' type declared in the concrete type // in order to conform to the IDifferentiable interface. diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index ce3e563f5..e7d5a0e5c 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -5,28 +5,6 @@ namespace Slang { -IRInst* getSpecializedVal(IRInst* inst) -{ - int loopLimit = 1024; - while (inst && inst->getOp() == kIROp_Specialize) - { - inst = as(inst)->getBase(); - loopLimit--; - if (loopLimit == 0) - return inst; - } - return inst; -} - -IRInst* getLeafFunc(IRInst* func) -{ - func = getSpecializedVal(func); - if (!func) - return nullptr; - if (auto genericFunc = as(func)) - return findInnerMostGenericReturnVal(genericFunc); - return func; -} struct CheckDifferentiabilityPassContext : public InstPassBase { @@ -47,7 +25,7 @@ public: bool _isFuncMarkedForAutoDiff(IRInst* func) { - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; for (auto decorations : func->getDecorations()) @@ -65,7 +43,7 @@ public: bool _isDifferentiableFuncImpl(IRInst* func, DifferentiableLevel level) { - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; @@ -103,7 +81,7 @@ public: } } - func = getLeafFunc(func); + func = getResolvedInstForDecorations(func); if (!func) return false; @@ -332,7 +310,7 @@ public: sink->diagnose( inst, Diagnostics::lossOfDerivativeDueToCallOfNonDifferentiableFunction, - getLeafFunc(call->getCallee()), + getResolvedInstForDecorations(call->getCallee()), requiredDiffLevel == DifferentiableLevel::Forward ? "forward" : "backward"); } } diff --git a/source/slang/slang-ir-init-local-var.cpp b/source/slang/slang-ir-init-local-var.cpp new file mode 100644 index 000000000..4b28db268 --- /dev/null +++ b/source/slang/slang-ir-init-local-var.cpp @@ -0,0 +1,34 @@ +// slang-ir-init-local-var.cpp +#include "slang-ir-init-local-var.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" + +namespace Slang +{ + +void initializeLocalVariables(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func) +{ + IRBuilder builder(sharedBuilder); + for (auto block : func->getBlocks()) + { + for (auto inst : block->getChildren()) + { + if (inst->getOp() == kIROp_Var) + { + auto firstUse = inst->firstUse; + bool initialized = + (firstUse && firstUse->getUser()->getOp() == kIROp_Store && + firstUse->getUser()->getParent() == inst->getParent()); + if (initialized) + continue; + builder.setInsertAfter(inst); + builder.emitStore( + inst, + builder.emitDefaultConstruct( + as(inst->getFullType())->getValueType())); + } + } + } +} + +} // namespace Slang diff --git a/source/slang/slang-ir-init-local-var.h b/source/slang/slang-ir-init-local-var.h new file mode 100644 index 000000000..ad06684fc --- /dev/null +++ b/source/slang/slang-ir-init-local-var.h @@ -0,0 +1,14 @@ +// slang-ir-init-local-var.h +#pragma once + +namespace Slang +{ + struct IRModule; + struct IRGlobalValueWithCode; + struct SharedIRBuilder; + + // Init local variables with default values if the variable isn't being initialized locally in + // the same basic block. + void initializeLocalVariables(SharedIRBuilder* sharedBuilder, IRGlobalValueWithCode* func); + +} diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f2294671e..e1143b7b9 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -326,6 +326,10 @@ INST(Var, var, 0, 0) INST(Load, load, 1, 0) INST(Store, store, 2, 0) +// Produced and removed during backward auto-diff pass as a temporary placeholder representing the +// currently accumulated derivative to pass to some dOut argument in a nested call. +INST(LoadReverseGradient, LoadReverseGradient, 1, 0) + INST(FieldExtract, get_field, 2, 0) INST(FieldAddress, get_field_addr, 2, 0) @@ -767,6 +771,12 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// forward-differentiated updateElement inst. INST(PrimalElementTypeDecoration, primalElementType, 1, 0) + /// Used by the auto-diff pass. An `out T` parameter will transcribe to a `in T.Differential` parameter. + /// We will also create a temp var of type `T.Differential` in the function body so the `load` and `stores` + /// can operand on a valid address. We use this decoration to associate this temp var with its corresponding + /// input parameter. + INST(OutParamReverseGradientDecoration, outParamRevGrad, 1, 0) + /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index aca832c0c..132a96f16 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -723,6 +723,18 @@ struct IRMixedDifferentialInstDecoration : IRDecoration IRType* getPairType() { return as(getOperand(0)); } }; +struct IROutParamReverseGradientDecoration : IRDecoration +{ + enum + { + kOp = kIROp_OutParamReverseGradientDecoration + }; + + IR_LEAF_ISA(OutParamReverseGradientDecoration) + + IRInst* getValue() { return getOperand(0); } +}; + struct IRBackwardDifferentiableDecoration : IRDecoration { enum @@ -1770,6 +1782,12 @@ struct IRGetElementPtr : IRInst IRInst* getIndex() { return getOperand(1); } }; +struct IRLoadReverseGradient :IRInst +{ + IR_LEAF_ISA(LoadReverseGradient) + IRInst* getValue() { return getOperand(0); } +}; + struct IRGetNativePtr : IRInst { IR_LEAF_ISA(GetNativePtr); @@ -2598,7 +2616,6 @@ public: IRInst* getBoolValue(bool value); IRInst* getIntValue(IRType* type, IRIntegerValue value); IRInst* getFloatValue(IRType* type, IRFloatingPointValue value); - IRInst* getDifferentialBottom(); IRStringLit* getStringValue(const UnownedStringSlice& slice); IRPtrLit* _getPtrValue(void* ptr); IRPtrLit* getNullPtrValue(IRType* type); @@ -2920,8 +2937,6 @@ public: IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); - IRInst* emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair); - IRInst* emitDifferentialPairAddressPrimal(IRInst* diffPair); IRInst* emitMakeVector( IRType* type, UInt argCount, @@ -3129,6 +3144,8 @@ public: IRInst* emitLoad( IRInst* ptr); + IRInst* emitLoadReverseGradient(IRType* type, IRInst* diffValue); + IRInst* emitStore( IRInst* dstPtr, IRInst* srcVal); diff --git a/source/slang/slang-ir-util.cpp b/source/slang/slang-ir-util.cpp index 3ffbb75f7..5cf074484 100644 --- a/source/slang/slang-ir-util.cpp +++ b/source/slang/slang-ir-util.cpp @@ -223,7 +223,9 @@ String dumpIRToString(IRInst* root) { StringBuilder sb; StringWriter writer(&sb, Slang::WriterFlag::AutoFlush); - dumpIR(root, IRDumpOptions(), nullptr, &writer); + IRDumpOptions options = {}; + options.flags = IRDumpOptions::Flag::DumpDebugIds; + dumpIR(root, options, nullptr, &writer); return sb.ToString(); } diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 2a4ae59a7..4814726cf 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3745,15 +3745,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) { - return emitIntrinsicInst( - diffType, - kIROp_DifferentialPairGetDifferential, - 1, - &diffPair); - } - - IRInst* IRBuilder::emitDifferentialPairAddressDifferential(IRType* diffType, IRInst* diffPair) - { + SLANG_ASSERT(as(diffPair->getDataType())); return emitIntrinsicInst( diffType, kIROp_DifferentialPairGetDifferential, @@ -3763,7 +3755,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) { - auto valueType = as(diffPair->getDataType())->getValueType(); + auto valueType = cast(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( valueType, kIROp_DifferentialPairGetPrimal, @@ -3771,16 +3763,6 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairAddressPrimal(IRInst* diffPair) - { - auto valueType = as( - as(diffPair->getDataType())->getValueType())->getValueType(); - return emitIntrinsicInst( - this->getPtrType(kIROp_PtrType, valueType), - kIROp_DifferentialPairGetPrimal, - 1, - &diffPair); - } IRInst* IRBuilder::emitMakeMatrix( IRType* type, @@ -4240,6 +4222,18 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitLoadReverseGradient(IRType* type, IRInst* diffValue) + { + auto inst = createInst( + this, + kIROp_LoadReverseGradient, + type, + diffValue); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitLoad( IRType* type, IRInst* ptr) @@ -6818,6 +6812,7 @@ namespace Slang case kIROp_MakeTuple: case kIROp_GetTupleElement: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads + case kIROp_LoadReverseGradient: case kIROp_ImageSubscript: case kIROp_FieldExtract: case kIROp_FieldAddress: -- cgit v1.2.3