diff options
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 30 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-propagate.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 281 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 244 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 24 | ||||
| -rw-r--r-- | tests/autodiff/reverse-nested-calls.slang | 29 | ||||
| -rw-r--r-- | tests/autodiff/reverse-nested-calls.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/reverse-struct-types.slang | 23 | ||||
| -rw-r--r-- | tests/autodiff/reverse-struct-types.slang.expected.txt | 2 |
11 files changed, 519 insertions, 131 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 60c2721c7..d1e9f91ec 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -615,7 +615,12 @@ InstPair ForwardDerivativeTranscriber::transcribeLoad(IRBuilder* builder, IRLoad // Special case load from an `out` param, which will not have corresponding `diff` and // `primal` insts yet. + // TODO: Could we move this load to _after_ DifferentialPairGetPrimal, + // and DifferentialPairGetDifferential? + // auto load = builder->emitLoad(primalPtr); + builder->markInstAsMixedDifferential(load, diffPairType); + auto primalElement = builder->emitDifferentialPairGetPrimal(load); auto diffElement = builder->emitDifferentialPairGetDifferential( (IRType*)pairBuilder->getDiffTypeFromPairType(builder, diffPairType), load); @@ -647,7 +652,7 @@ InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRSto if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType())) { auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); - builder->markInstAsDifferential(diffStoreVal, diffPairType); + builder->markInstAsMixedDifferential(diffStoreVal, diffPairType); auto store = builder->emitStore(primalStoreLocation, valToStore); return InstPair(store, nullptr); @@ -690,6 +695,7 @@ InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRRe // Neither of these should be nullptr. SLANG_RELEASE_ASSERT(primalReturnVal && diffReturnVal); IRReturn* diffReturn = as<IRReturn>(builder->emitReturn(diffReturnVal)); + builder->markInstAsMixedDifferential(diffReturn, nullptr); return InstPair(diffReturn, diffReturn); } @@ -704,9 +710,11 @@ InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRRe SLANG_RELEASE_ASSERT(diffReturnVal); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); - builder->markInstAsDifferential(diffPair, pairType); + builder->markInstAsMixedDifferential(diffPair, pairType); IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); + builder->markInstAsMixedDifferential(pairReturn, pairType); + return InstPair(pairReturn, pairReturn); } else @@ -804,7 +812,8 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall // If the user has already provided an differentiated implementation, use that. diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); } - else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>()) + else if (primalCallee->findDecoration<IRForwardDifferentiableDecoration>() || + primalCallee->findDecoration<IRBackwardDifferentiableDecoration>()) { // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass // to generate the implementation. @@ -851,7 +860,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall SLANG_RELEASE_ASSERT(diffArg); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); - builder->markInstAsDifferential(diffPair, pairType); + builder->markInstAsMixedDifferential(diffPair, pairType); args.add(diffPair); continue; @@ -875,7 +884,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall diffReturnType, diffCallee, args); - builder->markInstAsDifferential(callInst, origCall->getFullType()); + builder->markInstAsMixedDifferential(callInst, diffReturnType); if (diffReturnType->getOp() != kIROp_VoidType) { @@ -1578,8 +1587,15 @@ IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* ori builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } - // Tag the differential inst using a decoration. - builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType())); + // Tag the differential inst using a decoration (if it doesn't have one) + if (!pair.differential->findDecoration<IRDifferentialInstDecoration>() && + !pair.differential->findDecoration<IRMixedDifferentialInstDecoration>()) + { + // TODO: If the type is a 'relevant' pair type, need to mark it as mixed differential + // instead. + // + builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType())); + } break; } diff --git a/source/slang/slang-ir-autodiff-propagate.h b/source/slang/slang-ir-autodiff-propagate.h index 9518ccb34..0d5686899 100644 --- a/source/slang/slang-ir-autodiff-propagate.h +++ b/source/slang/slang-ir-autodiff-propagate.h @@ -15,6 +15,11 @@ bool isDifferentialInst(IRInst* inst) return inst->findDecoration<IRDifferentialInstDecoration>(); } +bool isMixedDifferentialInst(IRInst* inst) +{ + return inst->findDecoration<IRMixedDifferentialInstDecoration>(); +} + struct DiffPropagationPass : InstPassBase { AutoDiffSharedContext* autodiffContext; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 34a08ee93..c7fbc415a 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -589,7 +589,7 @@ struct BackwardDiffTranscriber { // Create inout version. auto inoutDiffPairType = builder->getInOutType(diffPairType); - auto newParam = builder->emitParam(inoutDiffPairType); + auto newParam = builder->emitParam(inoutDiffPairType); // Map the _load_ of the new parameter as the clone of the old one. auto newParamLoad = builder->emitLoad(newParam); diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 75491d753..a14ecad84 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -150,17 +150,25 @@ struct DiffTransposePass // Insert after the last block. builder.setInsertInto(revBlock); - List<IRInst*> ptrInsts; + // Move pointer & reference insts to the top of the reverse-mode block. + List<IRInst*> nonValueInsts; for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) { - // If the instruction is pointer typed, move to top of new reverse-mode block + // If the instruction is pointer typed, it's not actually computing a value. + // if (as<IRPtrTypeBase>(child->getDataType())) - ptrInsts.add(child); + nonValueInsts.add(child); + + // Slang doesn't support function values. So if we see a func-typed inst + // it's proabably a reference to a function. + // + if (as<IRFuncType>(child->getDataType())) + nonValueInsts.add(child); } - for (auto ptrInst : ptrInsts) + for (auto inst : nonValueInsts) { - ptrInst->insertAtEnd(revBlock); + inst->insertAtEnd(revBlock); } @@ -210,34 +218,6 @@ struct DiffTransposePass if (hasRevGradients(inst)) gradients = popRevGradients(inst); - // Are we dealing with DifferentialPairType? - if (as<IRDifferentialPairType>(inst->getDataType())) - { - // This will be a 'hybrid' primal-differential inst, - // so we add a pair (primal_value, 0) as an additional - // gradient to represent the primal part of the computation. - // - // Now, if the unzip pass has done it's job, the _only_ - // case should be that inst is IRMakeDifferentialPair - // - SLANG_ASSERT(as<IRMakeDifferentialPair>(inst)); - auto primalType = as<IRDifferentialPairType>(inst->getDataType())->getValueType(); - auto diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, as<IRDifferentialPairType>(inst->getDataType())); - - auto primalInst = as<IRMakeDifferentialPair>(inst)->getPrimalValue(); - auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType); - - // Must exist. - SLANG_ASSERT(zeroMethod); - auto diffInst = builder->emitCallInst(diffType, zeroMethod, List<IRInst*>()); - - gradients.add( - RevGradient( - inst, - builder->emitMakeDifferentialPair(inst->getDataType(), primalInst, diffInst), - nullptr)); - } - IRType* primalType = tryGetPrimalTypeFromDiffInst(inst); if (!primalType) @@ -249,6 +229,14 @@ struct DiffTransposePass tryGetPrimalTypeFromDiffInst(returnInst->getVal())); primalType = returnPairType->getValueType(); } + else if (auto loadInst = as<IRLoad>(inst)) + { + // TODO: Unzip loads properly to avoid having to side-step this check for IRLoad + if (auto pairType = as<IRDifferentialPairType>(loadInst->getDataType())) + { + primalType = pairType->getValueType(); + } + } } if (!primalType) @@ -278,6 +266,116 @@ struct DiffTransposePass addRevGradientForFwdInst(gradient.targetInst, gradient); } } + + TranspositionResult transposeCall(IRBuilder* builder, IRCall* fwdCall, IRInst* revValue) + { + auto fwdDiffCallee = as<IRForwardDifferentiate>(fwdCall->getCallee()); + + // If the callee is not a fwd-differentiate(fn), then there's only two + // cases. This is a call to something that doesn't need to be transposed + // or this is a user-written function calling something that isn't marked + // with IRForwardDifferentiate, but is handling differentials. + // We currently do not handle the latter. + // However, if we see a callee with no parameters, we can just skip over. + // since there's nothing to backpropagate to. + // + if (!fwdDiffCallee) + { + if (fwdCall->getArgCount() == 0) + { + return TranspositionResult(List<RevGradient>()); + } + else + { + SLANG_UNIMPLEMENTED_X( + "This case should only trigger on a user-defined fwd-mode function" + " calling another user-defined function not marked with __fwd_diff()"); + } + } + + auto baseFn = fwdDiffCallee->getBaseFn(); + + List<IRInst*> args; + List<IRType*> argTypes; + List<bool> isArgPtrTyped; + + for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++) + { + auto arg = fwdCall->getArg(ii); + + // If this isn't a ptr-type, make a var. + if (!as<IRPtrTypeBase>(arg->getDataType()) && as<IRDifferentialPairType>(arg->getDataType())) + { + auto pairType = as<IRDifferentialPairType>(arg->getDataType()); + + auto var = builder->emitVar(arg->getDataType()); + + SLANG_ASSERT(as<IRMakeDifferentialPair>(arg)); + + // Initialize this var to (arg.primal, 0). + builder->emitStore( + var, + builder->emitMakeDifferentialPair( + arg->getDataType(), + as<IRMakeDifferentialPair>(arg)->getPrimalValue(), + builder->emitCallInst( + (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()), + diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()), + List<IRInst*>()))); + + args.add(var); + argTypes.add(builder->getInOutType(pairType)); + isArgPtrTyped.add(false); + } + else + { + args.add(arg); + argTypes.add(arg->getDataType()); + isArgPtrTyped.add(true); + } + } + + args.add(revValue); + argTypes.add(revValue->getDataType()); + + auto revFnType = builder->getFuncType(argTypes, builder->getVoidType()); + auto revCallee = builder->emitBackwardDifferentiateInst( + revFnType, + baseFn); + + builder->emitCallInst(revFnType->getResultType(), revCallee, args); + + List<RevGradient> gradients; + for (UIndex ii = 0; ii < fwdCall->getArgCount(); ii++) + { + // Is this arg relevant to auto-diff? + if (as<IRDifferentialPairType>(as<IRPtrTypeBase>(args[ii]->getDataType())->getValueType())) + { + // If this is ptr typed, ignore (the gradient will be accumulated on the pointer) + // automatically. + // + if (!isArgPtrTyped[ii]) + { + auto diffArgType = (IRType*)diffTypeContext.getDifferentialForType( + builder, + as<IRDifferentialPairType>( + as<IRPtrTypeBase>(argTypes[ii])->getValueType())->getValueType()); + auto diffArgPtrType = builder->getPtrType(kIROp_PtrType, diffArgType); + + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdCall->getArg(ii), + builder->emitLoad( + builder->emitDifferentialPairAddressDifferential( + diffArgPtrType, + args[ii])), + nullptr)); + } + } + } + + return TranspositionResult(gradients); + } TranspositionResult transposeInst(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) { @@ -288,6 +386,9 @@ struct DiffTransposePass case kIROp_Mul: case kIROp_Sub: return transposeArithmetic(builder, fwdInst, revValue); + + case kIROp_Call: + return transposeCall(builder, as<IRCall>(fwdInst), revValue); case kIROp_swizzle: return transposeSwizzle(builder, as<IRSwizzle>(fwdInst), revValue); @@ -322,35 +423,49 @@ struct DiffTransposePass { auto revPtr = fwdLoad->getPtr(); + auto primalType = tryGetPrimalTypeFromDiffInst(fwdLoad); + auto loadType = fwdLoad->getDataType(); + + List<RevGradient> gradients(RevGradient( + revPtr, + revValue, + nullptr)); + if (usedPtrs.contains(revPtr)) { // Re-emit a load to get the _current_ value of revPtr. auto revCurrGrad = builder->emitLoad(revPtr); // Add the current value to the aggregation list. - List<RevGradient> gradients( - RevGradient( - revCurrGrad, - revValue, - nullptr), - RevGradient( - revCurrGrad, - revCurrGrad, - nullptr)); - - auto primalType = tryGetPrimalTypeFromDiffInst(fwdLoad); - // Get the _total_ value. - auto aggregateGradient = emitAggregateValue(builder, primalType, gradients); - - // Store this back into the pointer. - builder->emitStore(revPtr, aggregateGradient); + gradients.add(RevGradient( + revPtr, + revCurrGrad, + nullptr)); } else { usedPtrs.add(revPtr); + } + + // Get the _total_ value. + auto aggregateGradient = emitAggregateValue( + builder, + primalType, + gradients); + + if (as<IRDifferentialPairType>(loadType)) + { + auto primalPtr = builder->emitDifferentialPairAddressPrimal(revPtr); + auto primalVal = builder->emitLoad(primalPtr); + + auto pairVal = builder->emitMakeDifferentialPair(loadType, primalVal, aggregateGradient); - // Store into pointer - builder->emitStore(revPtr, revValue); + builder->emitStore(revPtr, pairVal); + } + else + { + // Store this back into the pointer. + builder->emitStore(revPtr, aggregateGradient); } return TranspositionResult(List<RevGradient>()); @@ -359,7 +474,6 @@ struct DiffTransposePass TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*) { - // (A = p.x) -> (p = float3(dA, 0, 0)) return TranspositionResult( List<RevGradient>( RevGradient( @@ -384,7 +498,6 @@ struct DiffTransposePass TranspositionResult transposeFieldExtract(IRBuilder*, IRFieldExtract* fwdExtract, IRInst* revValue) { - // (A = p.x) -> (p = float3(dA, 0, 0)) return TranspositionResult( List<RevGradient>( RevGradient( @@ -394,17 +507,19 @@ struct DiffTransposePass fwdExtract))); } - TranspositionResult transposeMakePair(IRBuilder* builder, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue) + TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue) { + // Even though makePair returns a pair of (primal, differential) + // revValue will only contain the reverse-value for 'differential' + // // (P = (A, dA)) -> (dA += dP) + // return TranspositionResult( List<RevGradient>( RevGradient( RevGradient::Flavor::Simple, fwdMakePair->getDifferentialValue(), - builder->emitDifferentialPairGetDifferential( - fwdMakePair->getDifferentialValue()->getDataType(), - revValue), + revValue, fwdMakePair))); } @@ -414,7 +529,7 @@ struct DiffTransposePass return TranspositionResult( List<RevGradient>( RevGradient( - RevGradient::Flavor::GetDifferential, + RevGradient::Flavor::Simple, fwdGetDiff->getBase(), revValue, fwdGetDiff))); @@ -448,39 +563,7 @@ struct DiffTransposePass // void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) { - auto revPtr = revLoad->getPtr(); - - // Assert that ptr type is of the form IRPtrTypeBase<IRDifferentialPairType<T>> - SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType())); - SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType); - - auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType()); - - // Gather gradients. - auto gradients = popRevGradients(revLoad); - if (gradients.getCount() == 0) - { - // Ignore. - return; - } - else - { - // Re-emit a load to get the _current_ value of revPtr. - auto revCurrGrad = builder->emitLoad(revPtr); - - // Add the current value to the aggregation list. - gradients.add( - RevGradient( - revLoad, - revCurrGrad, - nullptr)); - - // Get the _total_ value. - auto aggregateGradient = emitAggregateValue(builder, paramPairType, gradients); - - // Store this back into the pointer. - builder->emitStore(revPtr, aggregateGradient); - } + return transposeInst(builder, revLoad); } TranspositionResult transposeReturn(IRBuilder*, IRReturn* fwdReturn, IRInst* revValue) @@ -488,16 +571,14 @@ struct DiffTransposePass // TODO: This check needs to be changed to something like: isRelevantDifferentialPair() if (as<IRDifferentialPairType>(fwdReturn->getVal()->getDataType())) { - // This is a subtle case, even though the returned value is returning - // a pair, we need to pretend that the primal value is not being returned - // since we only care about transposing differential computation. - // So we're going to assume there is an implicit GetDifferential() - // around the return value before returning. + // Simply pass on the gradient to the previous inst. + // (Even if the return value is pair typed, we only care about the differential part) + // So this will remain a 'simple' gradient. // return TranspositionResult( List<RevGradient>( RevGradient( - RevGradient::Flavor::GetDifferential, + RevGradient::Flavor::Simple, fwdReturn->getVal(), revValue, fwdReturn))); @@ -856,6 +937,8 @@ struct DiffTransposePass IRInst* emitAggregateDifferentialPair(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> pairGradients) { + SLANG_UNEXPECTED("Should not run."); + auto aggPairType = as<IRDifferentialPairType>(aggPrimalType); SLANG_ASSERT(aggPairType); @@ -923,7 +1006,9 @@ struct DiffTransposePass // a differential pair is really a 'hybrid' primal-differential type. // if (as<IRDifferentialPairType>(aggPrimalType)) - return emitAggregateDifferentialPair(builder, aggPrimalType, gradients); + { + SLANG_UNEXPECTED("Should not occur"); + } // Process non-simple gradients into simple gradients. // TODO: This is where we can improve efficiency later. diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 79dec365c..2bfe972ec 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -17,12 +17,32 @@ struct DiffUnzipPass IRCloneEnv cloneEnv; + DifferentiableTypeConformanceContext diffTypeContext; + + // Maps used to keep track of primal and + // differential versions of split insts. + // + Dictionary<IRInst*, IRInst*> primalMap; + Dictionary<IRInst*, IRInst*> diffMap; + DiffUnzipPass(AutoDiffSharedContext* autodiffContext) : - autodiffContext(autodiffContext) + autodiffContext(autodiffContext), diffTypeContext(autodiffContext) { } + IRInst* lookupPrimalInst(IRInst* inst) + { + return primalMap[inst]; + } + + IRInst* lookupDiffInst(IRInst* inst) + { + return diffMap[inst]; + } + IRFunc* unzipDiffInsts(IRFunc* func) { + diffTypeContext.setFunc(func); + IRBuilder builderStorage; builderStorage.init(autodiffContext->sharedBuilder); @@ -66,6 +86,185 @@ struct DiffUnzipPass return unzippedFunc; } + bool isRelevantDifferentialPair(IRType* type) + { + if (as<IRDifferentialPairType>(type)) + { + return true; + } + else if (auto argPtrType = as<IRPtrTypeBase>(type)) + { + if (as<IRDifferentialPairType>(argPtrType->getValueType())) + { + return true; + } + } + + return false; + } + + InstPair splitCall(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRCall* mixedCall) + { + IRBuilder globalBuilder; + globalBuilder.init(autodiffContext->sharedBuilder); + + auto fwdCallee = as<IRForwardDifferentiate>(mixedCall->getCallee()); + auto fwdCalleeType = as<IRFuncType>(fwdCallee->getDataType()); + auto baseFn = fwdCallee->getBaseFn(); + + List<IRInst*> primalArgs; + for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) + { + auto arg = mixedCall->getArg(0); + + if (isRelevantDifferentialPair(arg->getDataType())) + { + primalArgs.add(lookupPrimalInst(arg)); + } + else + { + primalArgs.add(arg); + } + } + + auto mixedDecoration = mixedCall->findDecoration<IRMixedDifferentialInstDecoration>(); + SLANG_ASSERT(mixedDecoration); + + auto fwdPairResultType = as<IRDifferentialPairType>(mixedDecoration->getPairType()); + SLANG_ASSERT(fwdPairResultType); + + auto primalType = fwdPairResultType->getValueType(); + auto diffType = (IRType*) diffTypeContext.getDifferentialForType(&globalBuilder, primalType); + + auto primalVal = primalBuilder->emitCallInst(primalType, baseFn, primalArgs); + + List<IRInst*> diffArgs; + for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) + { + auto arg = mixedCall->getArg(0); + + if (isRelevantDifferentialPair(arg->getDataType())) + { + auto primalArg = lookupPrimalInst(arg); + auto diffArg = lookupDiffInst(arg); + + // If arg is a mixed differential (pair), it should have already been split. + SLANG_ASSERT(primalArg); + SLANG_ASSERT(diffArg); + + auto pairArg = diffBuilder->emitMakeDifferentialPair( + arg->getDataType(), + primalArg, + diffArg); + + diffBuilder->markInstAsDifferential(pairArg, primalArg->getDataType()); + diffArgs.add(pairArg); + } + else + { + diffArgs.add(arg); + } + } + + auto newFwdCallee = diffBuilder->emitForwardDifferentiateInst(fwdCalleeType, baseFn); + diffBuilder->markInstAsDifferential(newFwdCallee); + + auto diffPairVal = diffBuilder->emitCallInst( + fwdPairResultType, + newFwdCallee, + diffArgs); + diffBuilder->markInstAsDifferential(diffPairVal, primalType); + + auto diffVal = diffBuilder->emitDifferentialPairGetDifferential(diffType, diffPairVal); + diffBuilder->markInstAsDifferential(diffVal, primalType); + + return InstPair(primalVal, diffVal); + } + + InstPair splitMakePair(IRBuilder*, IRBuilder*, IRMakeDifferentialPair* mixedPair) + { + return InstPair(mixedPair->getPrimalValue(), mixedPair->getDifferentialValue()); + } + + InstPair splitLoad(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRLoad* mixedLoad) + { + // By the nature of how diff pairs are used, and the fact that FieldAddress/GetElementPtr, + // etc, cannot appear before a GetDifferential/GetPrimal, a mixed load can only be from a + // parameter or a variable. + // + if (as<IRParam>(mixedLoad->getPtr())) + { + // Should not occur with current impl of fwd-mode. + // If impl. changes, impl this case too. + // + SLANG_UNIMPLEMENTED_X("Splitting a load from a param is not currently implemented."); + } + + // Everything else should have already been split. + auto primalPtr = lookupPrimalInst(mixedLoad->getPtr()); + auto diffPtr = lookupDiffInst(mixedLoad->getPtr()); + + return InstPair(primalBuilder->emitLoad(primalPtr), diffBuilder->emitLoad(diffPtr)); + } + + InstPair splitVar(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRVar* mixedVar) + { + auto pairType = as<IRDifferentialPairType>(mixedVar->getDataType()); + auto primalType = pairType->getValueType(); + auto diffType = (IRType*) diffTypeContext.getDifferentialForType(primalBuilder, primalType); + + return InstPair(primalBuilder->emitVar(primalType), diffBuilder->emitVar(diffType)); + } + + InstPair splitReturn(IRBuilder*, IRBuilder* diffBuilder, IRReturn* mixedReturn) + { + auto pairType = as<IRDifferentialPairType>(mixedReturn->getVal()->getDataType()); + auto primalType = pairType->getValueType(); + + auto pairVal = diffBuilder->emitMakeDifferentialPair( + pairType, + lookupPrimalInst(mixedReturn->getVal()), + lookupDiffInst(mixedReturn->getVal())); + diffBuilder->markInstAsDifferential(pairVal, primalType); + + auto returnInst = diffBuilder->emitReturn(pairVal); + diffBuilder->markInstAsDifferential(returnInst, primalType); + + return InstPair(nullptr, returnInst); + } + + InstPair _splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Call: + return splitCall(primalBuilder, diffBuilder, as<IRCall>(inst)); + + case kIROp_Var: + return splitVar(primalBuilder, diffBuilder, as<IRVar>(inst)); + + case kIROp_MakeDifferentialPair: + return splitMakePair(primalBuilder, diffBuilder, as<IRMakeDifferentialPair>(inst)); + + case kIROp_Load: + return splitLoad(primalBuilder, diffBuilder, as<IRLoad>(inst)); + + case kIROp_Return: + return splitReturn(primalBuilder, diffBuilder, as<IRReturn>(inst)); + + default: + SLANG_ASSERT_FAILURE("Unhandled mixed diff inst"); + } + } + + void splitMixedInst(IRBuilder* primalBuilder, IRBuilder* diffBuilder, IRInst* inst) + { + auto instPair = _splitMixedInst(primalBuilder, diffBuilder, inst); + + primalMap[inst] = instPair.primal; + diffMap[inst] = instPair.differential; + } + void splitBlock(IRBlock* mainBlock, IRBlock* primalBlock, IRBlock* diffBlock) { // Make two builders for primal and differential blocks. @@ -77,14 +276,42 @@ struct DiffUnzipPass diffBuilder.init(autodiffContext->sharedBuilder); diffBuilder.setInsertInto(diffBlock); + List<IRInst*> splitInsts; for (auto child = mainBlock->getFirstChild(); child;) { IRInst* nextChild = child->getNextInst(); - if (isDifferentialInst(child) || as<IRTerminatorInst>(child)) + if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(child)) + { + if (diffMap.ContainsKey(getDiffInst->getBase())) + { + getDiffInst->replaceUsesWith(lookupDiffInst(getDiffInst->getBase())); + getDiffInst->removeAndDeallocate(); + child = nextChild; + continue; + } + } + + if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(child)) + { + if (primalMap.ContainsKey(getPrimalInst->getBase())) + { + getPrimalInst->replaceUsesWith(lookupPrimalInst(getPrimalInst->getBase())); + getPrimalInst->removeAndDeallocate(); + child = nextChild; + continue; + } + } + + if (isDifferentialInst(child)) { child->insertAtEnd(diffBlock); } + else if (isMixedDifferentialInst(child)) + { + splitMixedInst(&primalBuilder, &diffBuilder, child); + splitInsts.add(child); + } else { child->insertAtEnd(primalBlock); @@ -93,6 +320,19 @@ struct DiffUnzipPass child = nextChild; } + // Remove insts that were split. + for (auto inst : splitInsts) + { + // Consistency check. + for (auto use = inst->firstUse; use; use = use->nextUse) + { + SLANG_RELEASE_ASSERT((use->getUser()->getParent() != primalBlock) && + (use->getUser()->getParent() != diffBlock)); + } + + inst->removeAndDeallocate(); + } + // Nothing should be left in the original block. SLANG_ASSERT(mainBlock->getFirstChild() == nullptr); diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 5784f60cb..c74388406 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -737,6 +737,10 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// a differential value. INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) + /// Used by the auto-diff pass to mark insts that compute + /// BOTH a differential and a primal value. + INST(MixedDifferentialInstDecoration, mixedDiffInstDecoration, 1, 0) + /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. INST(DerivativeMemberDecoration, derivativeMemberDecoration, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 1ef0fa4f8..67f17f5b2 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -610,6 +610,20 @@ struct IRDifferentialInstDecoration : IRDecoration IRType* getPrimalType() { return as<IRType>(getOperand(0)); } }; + +struct IRMixedDifferentialInstDecoration : IRDecoration +{ + enum + { + kOp = kIROp_MixedDifferentialInstDecoration + }; + + IRUse pairType; + IR_LEAF_ISA(MixedDifferentialInstDecoration) + + IRType* getPairType() { return as<IRType>(getOperand(0)); } +}; + struct IRBackwardDifferentiableDecoration : IRDecoration { enum @@ -3377,6 +3391,16 @@ public: addDecoration(value, kIROp_DifferentialInstDecoration, nullptr); } + void markInstAsMixedDifferential(IRInst* value) + { + addDecoration(value, kIROp_MixedDifferentialInstDecoration, nullptr); + } + + void markInstAsMixedDifferential(IRInst* value, IRType* pairType) + { + addDecoration(value, kIROp_MixedDifferentialInstDecoration, pairType); + } + void markInstAsDifferential(IRInst* value, IRType* primalType) { addDecoration(value, kIROp_DifferentialInstDecoration, primalType); diff --git a/tests/autodiff/reverse-nested-calls.slang b/tests/autodiff/reverse-nested-calls.slang new file mode 100644 index 000000000..2b55efd60 --- /dev/null +++ b/tests/autodiff/reverse-nested-calls.slang @@ -0,0 +1,29 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef DifferentialPair<float> dpfloat; +typedef float.Differential dfloat; + +[BackwardDifferentiable] +float g(float y) +{ + return 4.0f * y; +} + +[BackwardDifferentiable] +float f(float x) +{ + return 3.0f * g(2.0f * x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + dpfloat dpa = dpfloat(1.0, 0.0); + + __bwd_diff(f)(dpa, 1.0f); + outputBuffer[0] = dpa.d; // Expect: 24.0 +} diff --git a/tests/autodiff/reverse-nested-calls.slang.expected.txt b/tests/autodiff/reverse-nested-calls.slang.expected.txt new file mode 100644 index 000000000..0a39c4da6 --- /dev/null +++ b/tests/autodiff/reverse-nested-calls.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +24.000000 +0.000000 +0.000000 +0.000000 +0.000000 diff --git a/tests/autodiff/reverse-struct-types.slang b/tests/autodiff/reverse-struct-types.slang index 699e50480..d2b52a008 100644 --- a/tests/autodiff/reverse-struct-types.slang +++ b/tests/autodiff/reverse-struct-types.slang @@ -9,27 +9,6 @@ struct A : IDifferentiable { float x; float y; - - [__unsafeForceInlineEarly] - static Differential dzero() - { - Differential b = {0.0, float.dzero()}; - return b; - } - - [__unsafeForceInlineEarly] - static Differential dadd(Differential a, Differential b) - { - Differential o = {a.x + b.x, 0.0}; - return o; - } - - [__unsafeForceInlineEarly] - static Differential dmul(This a, Differential b) - { - Differential o = {a.x * b.x, 0.0}; - return o; - } }; typedef DifferentialPair<A> dpA; @@ -56,7 +35,7 @@ void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) A.Differential dout = {1.0, 1.0}; __bwd_diff(f)(dpa, dout); - outputBuffer[0] = dpa.d.x; // Expect: 10 + outputBuffer[0] = dpa.d.x; // Expect: 7 outputBuffer[1] = dpa.d.y; // Expect: 0 } } diff --git a/tests/autodiff/reverse-struct-types.slang.expected.txt b/tests/autodiff/reverse-struct-types.slang.expected.txt index 82bc8f733..b94f4fec6 100644 --- a/tests/autodiff/reverse-struct-types.slang.expected.txt +++ b/tests/autodiff/reverse-struct-types.slang.expected.txt @@ -1,5 +1,5 @@ type: float -5.000000 +7.000000 0.000000 0.000000 0.000000 |
