diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-12-08 11:50:55 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-08 08:50:55 -0800 |
| commit | 468bb7ecf65c000c308adae511bf65a1ca4cc412 (patch) | |
| tree | 8042aaa77224d00f14a7267564ce7452ad6de67e /source | |
| parent | 53e891eb28ceac5f956399c65f2ae27d37f3d724 (diff) | |
More type support for reverse-mode (#2551)
* Add vector arithmetic test. Make gradient accumulation work for any IRLoad
* Added support for general vector types, and split transposition into transpose & materialize to allow emitting the fully accumulated gradient for complex types.
* Several bug fixes + finished up support for vector & struct types + removed prop pass
* minor fixes (int/uint casts)
* Removed IRConstruct
* Added some type casts to prevent warnings
* minor fix for unused variable
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 84 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-rev.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 882 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 11 |
7 files changed, 849 insertions, 147 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 336682bf4..1ffc45fbd 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -2136,9 +2136,9 @@ namespace Slang type->paramTypes.add(derivType); } } - + // Last parameter is the initial derivative of the original return type - type->paramTypes.add(originalType->resultType); + type->paramTypes.add(getDifferentialType(m_astBuilder, originalType->resultType, SourceLoc())); return type; } diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index d2d9a0e7d..60c2721c7 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -529,25 +529,54 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder, diffRight = diffRight ? diffRight : getDifferentialZeroOfType(builder, primalRight->getDataType()); auto resultType = primalArith->getDataType(); + auto diffType = (IRType*) differentiableTypeConformanceContext.getDifferentialForType(builder, resultType); + switch(origArith->getOp()) { case kIROp_Add: - return InstPair(primalArith, builder->emitAdd(resultType, diffLeft, diffRight)); + { + auto diffAdd = builder->emitAdd(diffType, diffLeft, diffRight); + builder->markInstAsDifferential(diffAdd, resultType); + + return InstPair(primalArith, diffAdd); + } + case kIROp_Mul: - return InstPair(primalArith, builder->emitAdd(resultType, - builder->emitMul(resultType, diffLeft, primalRight), - builder->emitMul(resultType, primalLeft, diffRight))); + { + auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); + auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); + builder->markInstAsDifferential(diffLeftTimesRight, resultType); + builder->markInstAsDifferential(diffRightTimesLeft, resultType); + + auto diffAdd = builder->emitAdd(diffType, diffLeftTimesRight, diffRightTimesLeft); + builder->markInstAsDifferential(diffAdd, resultType); + + return InstPair(primalArith, diffAdd); + } + case kIROp_Sub: - return InstPair(primalArith, builder->emitSub(resultType, diffLeft, diffRight)); + { + auto diffSub = builder->emitSub(diffType, diffLeft, diffRight); + builder->markInstAsDifferential(diffSub, resultType); + + return InstPair(primalArith, diffSub); + } case kIROp_Div: - return InstPair(primalArith, builder->emitDiv(resultType, - builder->emitSub( - resultType, - builder->emitMul(resultType, diffLeft, primalRight), - builder->emitMul(resultType, primalLeft, diffRight)), - builder->emitMul( - primalRight->getDataType(), primalRight, primalRight - ))); + { + auto diffLeftTimesRight = builder->emitMul(diffType, diffLeft, primalRight); + auto diffRightTimesLeft = builder->emitMul(diffType, primalLeft, diffRight); + auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft); + builder->markInstAsDifferential(diffLeftTimesRight, resultType); + builder->markInstAsDifferential(diffRightTimesLeft, resultType); + builder->markInstAsDifferential(diffSub, resultType); + + auto diffMul = builder->emitMul(resultType, primalRight, primalRight); + + auto diffDiv = builder->emitDiv(diffType, diffSub, diffMul); + builder->markInstAsDifferential(diffDiv, resultType); + + return InstPair(primalArith, diffDiv); + } default: getSink()->diagnose(origArith->sourceLoc, Diagnostics::unimplemented, @@ -558,7 +587,6 @@ InstPair ForwardDerivativeTranscriber::transcribeBinaryArith(IRBuilder* builder, return InstPair(primalArith, nullptr); } - InstPair ForwardDerivativeTranscriber::transcribeBinaryLogic(IRBuilder* builder, IRInst* origLogic) { SLANG_ASSERT(origLogic->getOperandCount() == 2); @@ -619,6 +647,8 @@ InstPair ForwardDerivativeTranscriber::transcribeStore(IRBuilder* builder, IRSto if (auto diffPairType = as<IRDifferentialPairType>(primalLocationPtrType->getValueType())) { auto valToStore = builder->emitMakeDifferentialPair(diffPairType, primalStoreVal, diffStoreVal); + builder->markInstAsDifferential(diffStoreVal, diffPairType); + auto store = builder->emitStore(primalStoreLocation, valToStore); return InstPair(store, nullptr); } @@ -674,6 +704,8 @@ InstPair ForwardDerivativeTranscriber::transcribeReturn(IRBuilder* builder, IRRe SLANG_RELEASE_ASSERT(diffReturnVal); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalReturnVal, diffReturnVal); + builder->markInstAsDifferential(diffPair, pairType); + IRReturn* pairReturn = as<IRReturn>(builder->emitReturn(diffPair)); return InstPair(pairReturn, pairReturn); } @@ -817,7 +849,10 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall // If a pair type can be formed, this must be non-null. SLANG_RELEASE_ASSERT(diffArg); + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); + builder->markInstAsDifferential(diffPair, pairType); + args.add(diffPair); continue; } @@ -826,7 +861,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall // Add original/primal argument. args.add(primalArg); } - + IRType* diffReturnType = nullptr; diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); @@ -840,6 +875,7 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall diffReturnType, diffCallee, args); + builder->markInstAsDifferential(callInst, origCall->getFullType()); if (diffReturnType->getOp() != kIROp_VoidType) { @@ -1145,7 +1181,11 @@ IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* build SLANG_RELEASE_ASSERT(zeroMethod); auto emptyArgList = List<IRInst*>(); - return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + + auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); + builder->markInstAsDifferential(callInst, primalType); + + return callInst; } else { @@ -1489,10 +1529,10 @@ InstPair ForwardDerivativeTranscriber::transcribeGeneric(IRBuilder* inBuilder, I diffGeneric->setFullType(diffType); - // Transcribe children from origFunc into diffFunc. - builder.setInsertInto(diffGeneric); - for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) - this->transcribe(&builder, block); + // Transcribe children from origFunc into diffFunc. + builder.setInsertInto(diffGeneric); + for (auto block = origGeneric->getFirstBlock(); block; block = block->getNextBlock()) + this->transcribe(&builder, block); return InstPair(primalGeneric, diffGeneric); } @@ -1537,6 +1577,10 @@ IRInst* ForwardDerivativeTranscriber::transcribe(IRBuilder* builder, IRInst* ori sb << "s_diff_" << primalNameHint->getName(); builder->addNameHintDecoration(pair.differential, sb.getUnownedSlice()); } + + // Tag the differential inst using a decoration. + builder->markInstAsDifferential(pair.differential, as<IRType>(pair.primal->getDataType())); + break; } } diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 8ec8f581c..34a08ee93 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -511,7 +511,10 @@ struct BackwardDiffTranscriber this->makeParameterBlock(builder, as<IRFunc>(fwdDiffFunc)); // This steps adds a decoration to instructions that are computing the differential. - diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc); + // TODO: This is disabled for now because fwd-mode already adds differential decorations + // wherever need. We need to run this pass only for user-writted forward derivativecode. + // + // diffPropagationPass->propagateDiffInstDecoration(builder, fwdDiffFunc); // Copy primal insts to the first block of the unzipped function, copy diff insts to the // second block of the unzipped function. diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 659131820..75491d753 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -13,39 +13,61 @@ namespace Slang struct DiffTransposePass { - AutoDiffSharedContext* autodiffContext; - - DifferentialPairTypeBuilder pairBuilder; + + struct RevGradient + { + enum Flavor + { + Simple, + Swizzle, + GetElement, + GetDifferential, + FieldExtract, - Dictionary<IRInst*, List<IRInst*>> assignmentsMap; + Invalid + }; - Dictionary<IRInst*, IRInst*>* primalsMap; + RevGradient() : + flavor(Flavor::Invalid), targetInst(nullptr), revGradInst(nullptr), fwdGradInst(nullptr) + { } + + RevGradient(Flavor flavor, IRInst* targetInst, IRInst* revGradInst, IRInst* fwdGradInst) : + flavor(flavor), targetInst(targetInst), revGradInst(revGradInst), fwdGradInst(fwdGradInst) + { } - DiffTransposePass(AutoDiffSharedContext* autodiffContext) : - autodiffContext(autodiffContext), pairBuilder(autodiffContext) - { } + RevGradient(IRInst* targetInst, IRInst* revGradInst, IRInst* fwdGradInst) : + flavor(Flavor::Simple), targetInst(targetInst), revGradInst(revGradInst), fwdGradInst(fwdGradInst) + { } - struct RevAssignment - { - IRInst* lvalue; - IRInst* rvalue; + bool operator==(const RevGradient& other) const + { + return (other.targetInst == targetInst) && + (other.revGradInst == revGradInst) && + (other.fwdGradInst == fwdGradInst) && + (other.flavor == flavor); + } + + IRInst* targetInst; + IRInst* revGradInst; + IRInst* fwdGradInst; - RevAssignment(IRInst* lvalue, IRInst* rvalue) : lvalue(lvalue), rvalue(rvalue) - { } - RevAssignment() : lvalue(nullptr), rvalue(nullptr) - { } + Flavor flavor; }; + DiffTransposePass(AutoDiffSharedContext* autodiffContext) : + autodiffContext(autodiffContext), pairBuilder(autodiffContext), diffTypeContext(autodiffContext) + { } + struct TranspositionResult { // Holds a set of pairs of // (original-inst, inst-to-accumulate-for-orig-inst) - List<RevAssignment> revPairs; + List<RevGradient> revPairs; TranspositionResult() { } - TranspositionResult(List<RevAssignment> revPairs) : revPairs(revPairs) + TranspositionResult(List<RevGradient> revPairs) : revPairs(revPairs) { } }; @@ -64,9 +86,10 @@ struct DiffTransposePass void transposeDiffBlocksInFunc( IRFunc* revDiffFunc, - // TODO: Maybe there's a more elegant way to pass this information. FuncTranspositionInfo transposeInfo) { + // Grab all differentiable type information. + diffTypeContext.setFunc(revDiffFunc); // Traverse all instructions/blocks in reverse (starting from the terminator inst) // look for insts/blocks marked with IRDifferentialInstDecoration, @@ -103,7 +126,7 @@ struct DiffTransposePass // Set dOutParameter as the transpose gradient for the return inst, if any. if (auto returnInst = as<IRReturn>(block->getTerminator())) { - this->addRevAssignmentForFwdInst(returnInst, transposeInfo.dOutInst); + this->addRevGradientForFwdInst(returnInst, RevGradient(returnInst, transposeInfo.dOutInst, nullptr)); } IRBlock* revBlock = builder.emitBlock(); @@ -117,6 +140,8 @@ struct DiffTransposePass } } + // A[cond_inst] -> (B or C) -> D => D[cond_inst] -> (B_T -> C_T) -> A_T + void transposeBlock(IRBlock* fwdBlock, IRBlock* revBlock) { IRBuilder builder; @@ -125,7 +150,24 @@ struct DiffTransposePass // Insert after the last block. builder.setInsertInto(revBlock); + List<IRInst*> ptrInsts; + for (IRInst* child = fwdBlock->getFirstOrdinaryInst(); child; child = child->getNextInst()) + { + // If the instruction is pointer typed, move to top of new reverse-mode block + if (as<IRPtrTypeBase>(child->getDataType())) + ptrInsts.add(child); + } + + for (auto ptrInst : ptrInsts) + { + ptrInst->insertAtEnd(revBlock); + } + + + // Then, go backwards through the regular instructions, and transpose them into the new + // rev block. // Note the 'reverse' traversal here. + // for (IRInst* child = fwdBlock->getLastChild(); child; child = child->getPrevInst()) { if (as<IRDecoration>(child)) @@ -141,7 +183,7 @@ struct DiffTransposePass // function scope variable, since control flow can affect what blocks contribute to // for a specific inst. // - for (auto pair : assignmentsMap) + for (auto pair : gradientsMap) { if (auto param = as<IRLoad>(pair.Key)) accumulateGradientsForLoad(&builder, param); @@ -163,20 +205,77 @@ struct DiffTransposePass void transposeInst(IRBuilder* builder, IRInst* inst) { - // Look for assignment entry for this inst. - IRInst* revValue = builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0); - if (hasRevAssignments(inst)) + // Look for gradient entries for this inst. + List<RevGradient> gradients; + if (hasRevGradients(inst)) + gradients = popRevGradients(inst); + + // Are we dealing with DifferentialPairType? + if (as<IRDifferentialPairType>(inst->getDataType())) + { + // This will be a 'hybrid' primal-differential inst, + // so we add a pair (primal_value, 0) as an additional + // gradient to represent the primal part of the computation. + // + // Now, if the unzip pass has done it's job, the _only_ + // case should be that inst is IRMakeDifferentialPair + // + SLANG_ASSERT(as<IRMakeDifferentialPair>(inst)); + auto primalType = as<IRDifferentialPairType>(inst->getDataType())->getValueType(); + auto diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, as<IRDifferentialPairType>(inst->getDataType())); + + auto primalInst = as<IRMakeDifferentialPair>(inst)->getPrimalValue(); + auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType); + + // Must exist. + SLANG_ASSERT(zeroMethod); + auto diffInst = builder->emitCallInst(diffType, zeroMethod, List<IRInst*>()); + + gradients.add( + RevGradient( + inst, + builder->emitMakeDifferentialPair(inst->getDataType(), primalInst, diffInst), + nullptr)); + } + + IRType* primalType = tryGetPrimalTypeFromDiffInst(inst); + + if (!primalType) + { + // Special-case instructions. + if (auto returnInst = as<IRReturn>(inst)) + { + auto returnPairType = as<IRDifferentialPairType>( + tryGetPrimalTypeFromDiffInst(returnInst->getVal())); + primalType = returnPairType->getValueType(); + } + } + + if (!primalType) { - // Emit the aggregate of all the assignments here. This will form the derivative - revValue = emitAggregateValue(builder, popRevAssignments(inst)); + // Check for special insts for which a reverse-mode gradient doesn't apply. + if(!as<IRStore>(inst)) + { + SLANG_UNEXPECTED("Could not resolve primal type for diff inst"); + } } + // Emit the aggregate of all the gradients here. This will form the total derivative for this inst. + auto revValue = emitAggregateValue(builder, primalType, gradients); + auto transposeResult = transposeInst(builder, inst, revValue); - // Add the new results to the assignments map. - for (auto pair : transposeResult.revPairs) + if (auto fwdNameHint = inst->findDecoration<IRNameHintDecoration>()) { - addRevAssignmentForFwdInst(pair.lvalue, pair.rvalue); + StringBuilder sb; + sb << fwdNameHint->getName() << "_T"; + builder->addNameHintDecoration(revValue, sb.getUnownedSlice()); + } + + // Add the new results to the gradients map. + for (auto gradient : transposeResult.revPairs) + { + addRevGradientForFwdInst(gradient.targetInst, gradient); } } @@ -189,59 +288,176 @@ struct DiffTransposePass case kIROp_Mul: case kIROp_Sub: return transposeArithmetic(builder, fwdInst, revValue); + + case kIROp_swizzle: + return transposeSwizzle(builder, as<IRSwizzle>(fwdInst), revValue); + + case kIROp_FieldExtract: + return transposeFieldExtract(builder, as<IRFieldExtract>(fwdInst), revValue); case kIROp_Return: return transposeReturn(builder, as<IRReturn>(fwdInst), revValue); + + case kIROp_Store: + return transposeStore(builder, as<IRStore>(fwdInst), revValue); + + case kIROp_Load: + return transposeLoad(builder, as<IRLoad>(fwdInst), revValue); case kIROp_MakeDifferentialPair: return transposeMakePair(builder, as<IRMakeDifferentialPair>(fwdInst), revValue); case kIROp_DifferentialPairGetDifferential: return transposeGetDifferential(builder, as<IRDifferentialPairGetDifferential>(fwdInst), revValue); + + case kIROp_MakeVector: + return transposeMakeVector(builder, fwdInst, revValue); default: SLANG_ASSERT_FAILURE("Unhandled instruction"); } } - TranspositionResult transposeMakePair(IRBuilder*, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue) + TranspositionResult transposeLoad(IRBuilder* builder, IRLoad* fwdLoad, IRInst* revValue) + { + auto revPtr = fwdLoad->getPtr(); + + if (usedPtrs.contains(revPtr)) + { + // Re-emit a load to get the _current_ value of revPtr. + auto revCurrGrad = builder->emitLoad(revPtr); + + // Add the current value to the aggregation list. + List<RevGradient> gradients( + RevGradient( + revCurrGrad, + revValue, + nullptr), + RevGradient( + revCurrGrad, + revCurrGrad, + nullptr)); + + auto primalType = tryGetPrimalTypeFromDiffInst(fwdLoad); + // Get the _total_ value. + auto aggregateGradient = emitAggregateValue(builder, primalType, gradients); + + // Store this back into the pointer. + builder->emitStore(revPtr, aggregateGradient); + } + else + { + usedPtrs.add(revPtr); + + // Store into pointer + builder->emitStore(revPtr, revValue); + } + + return TranspositionResult(List<RevGradient>()); + } + + + TranspositionResult transposeStore(IRBuilder* builder, IRStore* fwdStore, IRInst*) + { + // (A = p.x) -> (p = float3(dA, 0, 0)) + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Simple, + fwdStore->getVal(), + builder->emitLoad(fwdStore->getPtr()), + fwdStore))); + } + + TranspositionResult transposeSwizzle(IRBuilder*, IRSwizzle* fwdSwizzle, IRInst* revValue) + { + // (A = p.x) -> (p = float3(dA, 0, 0)) + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Swizzle, + fwdSwizzle->getBase(), + revValue, + fwdSwizzle))); + } + + + TranspositionResult transposeFieldExtract(IRBuilder*, IRFieldExtract* fwdExtract, IRInst* revValue) + { + // (A = p.x) -> (p = float3(dA, 0, 0)) + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::FieldExtract, + fwdExtract->getBase(), + revValue, + fwdExtract))); + } + + TranspositionResult transposeMakePair(IRBuilder* builder, IRMakeDifferentialPair* fwdMakePair, IRInst* revValue) { // (P = (A, dA)) -> (dA += dP) return TranspositionResult( - List<RevAssignment>( - RevAssignment( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Simple, fwdMakePair->getDifferentialValue(), - revValue))); + builder->emitDifferentialPairGetDifferential( + fwdMakePair->getDifferentialValue()->getDataType(), + revValue), + fwdMakePair))); } TranspositionResult transposeGetDifferential(IRBuilder*, IRDifferentialPairGetDifferential* fwdGetDiff, IRInst* revValue) { // (A = GetDiff(P)) -> (dP.d += dA) return TranspositionResult( - List<RevAssignment>( - RevAssignment( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::GetDifferential, fwdGetDiff->getBase(), - revValue))); + revValue, + fwdGetDiff))); } - // Gather all reverse-mode gradients for parameters, and store to the differential - // - void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) + TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) { - auto revParam = revLoad->getPtr(); + // For now, we support only vector types. Extend this to other built-in types if necessary. + SLANG_ASSERT(fwdMakeVector->getOp() == kIROp_MakeVector); - // Don't currently handle loads from non-param insts. - SLANG_ASSERT(as<IRParam>(revParam)); + List<RevGradient> gradients; + for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++) + { + auto gradAtIndex = builder->emitElementExtract( + fwdMakeVector->getOperand(ii)->getDataType(), + revValue, + builder->getIntValue(builder->getIntType(), ii)); + + gradients.add(RevGradient( + RevGradient::Flavor::Simple, + fwdMakeVector->getOperand(ii), + gradAtIndex, + fwdMakeVector)); + } + + // (A = float3(X, Y, Z)) -> [(dX += dA), (dY += dA), (dZ += dA)] + return TranspositionResult(gradients); + } - // Assert that param type is of the form IRPtrTypeBase<IRDifferentialPairType<T>> - SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType())); - SLANG_ASSERT(as<IRPtrTypeBase>(revParam->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType); + // Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr. + // + void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) + { + auto revPtr = revLoad->getPtr(); + + // Assert that ptr type is of the form IRPtrTypeBase<IRDifferentialPairType<T>> + SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType())); + SLANG_ASSERT(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType()->getOp() == kIROp_DifferentialPairType); - auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revParam->getDataType())->getValueType()); - auto diffType = (IRType*) pairBuilder.getDiffTypeFromPairType(builder, paramPairType); + auto paramPairType = as<IRDifferentialPairType>(as<IRPtrTypeBase>(revPtr->getDataType())->getValueType()); // Gather gradients. - auto gradients = popRevAssignments(revLoad); + auto gradients = popRevGradients(revLoad); if (gradients.getCount() == 0) { // Ignore. @@ -249,42 +465,42 @@ struct DiffTransposePass } else { - // Re-emit a load to get the _current_ value of revParam. - auto revCurrLoad = builder->emitLoad(revParam); - - // Grab the current gradient value. - auto revCurrGrad = builder->emitDifferentialPairGetDifferential(diffType, revCurrLoad); + // Re-emit a load to get the _current_ value of revPtr. + auto revCurrGrad = builder->emitLoad(revPtr); // Add the current value to the aggregation list. - gradients.add(revCurrGrad); + gradients.add( + RevGradient( + revLoad, + revCurrGrad, + nullptr)); // Get the _total_ value. - auto aggregateGradient = emitAggregateValue(builder, gradients); - - // Grab the current primal value. - auto revCurrPrimal = builder->emitDifferentialPairGetPrimal(revCurrLoad); + auto aggregateGradient = emitAggregateValue(builder, paramPairType, gradients); - // Make the pair with the new gradient. - auto newDiffPair = builder->emitMakeDifferentialPair(paramPairType, revCurrPrimal, aggregateGradient); - - // Store this back into the parameter. - builder->emitStore(revParam, newDiffPair); + // Store this back into the pointer. + builder->emitStore(revPtr, aggregateGradient); } } TranspositionResult transposeReturn(IRBuilder*, IRReturn* fwdReturn, IRInst* revValue) { - + // TODO: This check needs to be changed to something like: isRelevantDifferentialPair() if (as<IRDifferentialPairType>(fwdReturn->getVal()->getDataType())) { - // If the type is a differential pair, we add the reverse-value for the *pair* - // itself. TODO: Signal this through flags in the 'RevAssignment' struct. - // (return (A, dA)) -> (dA += dOut) + // This is a subtle case, even though the returned value is returning + // a pair, we need to pretend that the primal value is not being returned + // since we only care about transposing differential computation. + // So we're going to assume there is an implicit GetDifferential() + // around the return value before returning. + // return TranspositionResult( - List<RevAssignment>( - RevAssignment( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::GetDifferential, fwdReturn->getVal(), - revValue))); + revValue, + fwdReturn))); } else { @@ -293,35 +509,136 @@ struct DiffTransposePass } } + IRInst* promoteToType(IRBuilder* builder, IRType* targetType, IRInst* inst) + { + auto currentType = inst->getDataType(); + + switch (targetType->getOp()) + { + + case kIROp_VectorType: + { + // current type should be a scalar. + SLANG_RELEASE_ASSERT(!as<IRVectorType>(currentType->getDataType())); + + auto targetVectorType = as<IRVectorType>(targetType); + + List<IRInst*> operands; + for (Index ii = 0; ii < as<IRIntLit>(targetVectorType->getElementCount())->getValue(); ii++) + { + operands.add(inst); + } + + IRInst* newInst = builder->emitMakeVector(targetType, operands.getCount(), operands.getBuffer()); + + if (isDifferentialInst(inst)) + builder->markInstAsDifferential(newInst); + + return newInst; + } + + default: + SLANG_ASSERT_FAILURE("Unhandled target type for promotion"); + } + } + + IRInst* promoteOperandsToTargetType(IRBuilder* builder, IRInst* fwdInst) + { + auto oldLoc = builder->getInsertLoc(); + // If operands are not of the same type, cast them to the target type. + IRType* targetType = fwdInst->getDataType(); + + bool needNewInst = false; + + List<IRInst*> newOperands; + for (UIndex ii = 0; ii < fwdInst->getOperandCount(); ii++) + { + auto operand = fwdInst->getOperand(ii); + if (operand->getDataType() != targetType) + { + // Insert new operand just after the old operand, so we have the old + // operands available. + // + builder->setInsertAfter(operand); + + IRInst* newOperand = promoteToType(builder, targetType, operand); + newOperands.add(newOperand); + + needNewInst = true; + } + else + { + newOperands.add(operand); + } + } + + if(needNewInst) + { + builder->setInsertAfter(fwdInst); + IRInst* newInst = builder->emitIntrinsicInst( + fwdInst->getDataType(), + fwdInst->getOp(), + newOperands.getCount(), + newOperands.getBuffer()); + + builder->setInsertLoc(oldLoc); + + if (isDifferentialInst(fwdInst)) + builder->markInstAsDifferential(newInst); + + return newInst; + } + else + { + builder->setInsertLoc(oldLoc); + return fwdInst; + } + } + TranspositionResult transposeArithmetic(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) { - IRType* floatType = builder->getType(kIROp_FloatType); + + // Only handle arithmetic on uniform types. If the types aren't uniform, we need some + // promotion/demotion logic. Note that this can create a new inst in place of the old, but since we're + // at the transposition step for the old inst, and already have it's aggregate gradient, there's + // no need to worry about the 'gradientsMap' being out-of-date + // TODO: There are some opportunities for optimization here (otherwise we might be increasing the intermediate + // data size unnecessarily) + // + fwdInst = promoteOperandsToTargetType(builder, fwdInst); + + auto operandType = fwdInst->getOperand(0)->getDataType(); + switch(fwdInst->getOp()) { case kIROp_Add: { // (Out = dA + dB) -> [(dA += dOut), (dB += dOut)] return TranspositionResult( - List<RevAssignment>( - RevAssignment( + List<RevGradient>( + RevGradient( fwdInst->getOperand(0), - revValue), - RevAssignment( + revValue, + fwdInst), + RevGradient( fwdInst->getOperand(1), - revValue))); + revValue, + fwdInst))); } case kIROp_Sub: { // (Out = dA - dB) -> [(dA += dOut), (dB -= dOut)] return TranspositionResult( - List<RevAssignment>( - RevAssignment( + List<RevGradient>( + RevGradient( fwdInst->getOperand(0), - revValue), - RevAssignment( + revValue, + fwdInst), + RevGradient( fwdInst->getOperand(1), builder->emitNeg( - revValue->getDataType(), revValue)))); + revValue->getDataType(), revValue), + fwdInst))); } case kIROp_Mul: { @@ -329,19 +646,21 @@ struct DiffTransposePass { // (Out = dA * B) -> (dA += B * dOut) return TranspositionResult( - List<RevAssignment>( - RevAssignment( + List<RevGradient>( + RevGradient( fwdInst->getOperand(0), - builder->emitMul(floatType, fwdInst->getOperand(1), revValue)))); + builder->emitMul(operandType, fwdInst->getOperand(1), revValue), + fwdInst))); } else if (isDifferentialInst(fwdInst->getOperand(1))) { // (Out = A * dB) -> (dB += A * dOut) return TranspositionResult( - List<RevAssignment>( - RevAssignment( + List<RevGradient>( + RevGradient( fwdInst->getOperand(1), - builder->emitMul(floatType, fwdInst->getOperand(0), revValue)))); + builder->emitMul(operandType, fwdInst->getOperand(0), revValue), + fwdInst))); } else { @@ -354,66 +673,397 @@ struct DiffTransposePass } } - IRInst* emitAggregateValue(IRBuilder* builder, List<IRInst*> values) + RevGradient materializeSwizzleGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) { - // We're handling the case where the types are all float, - // so we can use a bunch of kIROp_Add insts to add them up. - // If this is an arbitrary type T, we need to lookup and - // call T.dadd() + List<RevGradient> simpleGradients; - IRInst* initialValue = builder->getFloatValue(builder->getType(kIROp_FloatType), 0.0); - if (values.getCount() == 0) + for (auto gradient : gradients) { - // If there's not values to add up, emit a 0 value. - return initialValue; + // Peek at the fwd-mode swizzle inst to see what type we need to materialize. + IRSwizzle* fwdSwizzleInst = as<IRSwizzle>(gradient.fwdGradInst); + SLANG_ASSERT(fwdSwizzleInst); + + auto baseType = fwdSwizzleInst->getBase()->getDataType(); + + // Assume for now that this is a vector type. + SLANG_ASSERT(as<IRVectorType>(baseType)); + + IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount(); + IRType* elementType = as<IRVectorType>(baseType)->getElementType(); + + // Must be a concrete integer (auto-diff must always occur after specialization) + // For generic code, we would need to generate a for loop. + // + SLANG_ASSERT(as<IRIntLit>(elementCountInst)); + + auto elementCount = as<IRIntLit>(elementCountInst)->getValue(); + + // Make a list of 0s + List<IRInst*> constructArgs; + auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType); + + // Must exist. + SLANG_ASSERT(zeroMethod); + + auto zeroValueInst = builder->emitCallInst(elementType, zeroMethod, List<IRInst*>()); + + for (Index ii = 0; ii < ((Index)elementCount); ii++) + { + constructArgs.add(zeroValueInst); + } + + // Replace swizzled elements with their gradients. + for (Index ii = 0; ii < ((Index)fwdSwizzleInst->getElementCount()); ii++) + { + auto sourceIndex = ii; + auto targetIndexInst = fwdSwizzleInst->getElementIndex(ii); + SLANG_ASSERT(as<IRIntLit>(targetIndexInst)); + auto targetIndex = as<IRIntLit>(targetIndexInst)->getValue(); + + // Special-case for when the swizzled output is a single element. + if (fwdSwizzleInst->getElementCount() == 1) + { + constructArgs[(Index)targetIndex] = gradient.revGradInst; + } + else + { + auto gradAtIndex = builder->emitElementExtract(elementType, gradient.revGradInst, builder->getIntValue(builder->getIntType(), sourceIndex)); + constructArgs[(Index)targetIndex] = gradAtIndex; + } + } + + simpleGradients.add( + RevGradient( + gradient.targetInst, + builder->emitMakeVector(baseType, (UInt)elementCount, constructArgs.getBuffer()), + gradient.fwdGradInst)); } - else if (values.getCount() == 1) + + return materializeSimpleGradients(builder, aggPrimalType, simpleGradients); + } + + RevGradient materializeGradientSet(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) + { + switch (gradients[0].flavor) + { + case RevGradient::Flavor::Simple: + return materializeSimpleGradients(builder, aggPrimalType, gradients); + + case RevGradient::Flavor::Swizzle: + return materializeSwizzleGradients(builder, aggPrimalType, gradients); + + case RevGradient::Flavor::FieldExtract: + return materializeFieldExtractGradients(builder, aggPrimalType, gradients); + + default: + SLANG_ASSERT_FAILURE("Unhandled gradient flavor for materialization"); + } + } + + RevGradient materializeFieldExtractGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) + { + // Setup a temporary variable to aggregate gradients. + // TODO: We can extend this later to grab an existing ptr to allow aggregation of + // gradients across blocks without constructing new variables. + // Looking up an existing pointer could also allow chained accesses like x.a.b[1] to directly + // write into the specific sub-field that is affected without constructing intermediate vars. + // + auto revGradVar = builder->emitVar( + (IRType*)diffTypeContext.getDifferentialForType(builder, aggPrimalType)); + + // Initialize with T.dzero() + auto zeroValueInst = emitDZeroOfDiffInstType(builder, aggPrimalType); + + builder->emitStore(revGradVar, zeroValueInst); + + Dictionary<IRStructKey*, List<RevGradient>> bucketedGradients; + for (auto gradient : gradients) + { + // Grab the field affected by this gradient. + auto fieldExtractInst = as<IRFieldExtract>(gradient.fwdGradInst); + SLANG_ASSERT(fieldExtractInst); + + auto structKey = as<IRStructKey>(fieldExtractInst->getField()); + SLANG_ASSERT(structKey); + + if (!bucketedGradients.ContainsKey(structKey)) + { + bucketedGradients[structKey] = List<RevGradient>(); + } + + bucketedGradients[structKey].GetValue().add(RevGradient( + RevGradient::Flavor::Simple, + gradient.targetInst, + gradient.revGradInst, + gradient.fwdGradInst + )); + + } + + for (auto pair : bucketedGradients) + { + auto subGrads = pair.Value; + + auto primalType = tryGetPrimalTypeFromDiffInst(subGrads[0].fwdGradInst); + + SLANG_ASSERT(primalType); + + // Consruct address to this field in revGradVar. + auto revGradTargetAddress = builder->emitFieldAddress( + builder->getPtrType(subGrads[0].revGradInst->getDataType()), + revGradVar, + pair.Key); + + builder->emitStore(revGradTargetAddress, emitAggregateValue(builder, primalType, subGrads)); + } + + // Load the entire var and return it. + return RevGradient( + RevGradient::Flavor::Simple, + gradients[0].targetInst, + builder->emitLoad(revGradVar), + nullptr); + } + + RevGradient materializeSimpleGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) + { + if (gradients.getCount() == 1) { // If there's only one value to add up, just return it in order // to avoid a stack of 0 + 0 + 0 + ... - return values[0]; + return gradients[0]; + } + + // If there's more than one gradient, aggregate them by adding them up. + IRInst* currentValue = nullptr; + for (auto gradient : gradients) + { + if (!currentValue) + { + currentValue = gradient.revGradInst; + continue; + } + + currentValue = emitDAddOfDiffInstType(builder, aggPrimalType, currentValue, gradient.revGradInst); } - // If there's more than one value, aggregate them by adding them up. + return RevGradient( + RevGradient::Flavor::Simple, + gradients[0].targetInst, + currentValue, + nullptr); + } + + IRInst* emitAggregateDifferentialPair(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> pairGradients) + { + auto aggPairType = as<IRDifferentialPairType>(aggPrimalType); + SLANG_ASSERT(aggPairType); - SLANG_ASSERT(values[0]->getDataType()->getOp() == kIROp_FloatType); + IRType* diffType = (IRType*)pairBuilder.getDiffTypeFromPairType(builder, aggPairType); - IRInst* currentValue = initialValue; - for (auto value : values) + IRInst* primalInst = nullptr; + IRInst* diffInst = nullptr; + + List<RevGradient> gradients; + for (auto gradient : pairGradients) { - currentValue = builder->emitAdd( - builder->getType(kIROp_FloatType), currentValue, value); + switch (gradient.flavor) + { + case RevGradient::Flavor::Simple: + { + // In this case, the gradient is a 'pair' already, but we need to treat the primal element + // as if it didn't exist (we simply copy it over) + // If we already saw a pair, throw an error since we don't know how to combine to primals. + // (i.e. something went wrong prior to this step.) + // + if (primalInst) + { + SLANG_UNEXPECTED("Encountered multiple pair types in emitAggregateDifferentialPair"); + } + + primalInst = builder->emitDifferentialPairGetPrimal(gradient.revGradInst); + gradients.add( + RevGradient( + RevGradient::Flavor::Simple, + gradient.targetInst, + builder->emitDifferentialPairGetDifferential( + diffType, + gradient.revGradInst), + gradient.fwdGradInst)); + break; + } + + case RevGradient::Flavor::GetDifferential: + { + // In this case, the gradient is the result of transposing a GetDifferential + // so we have only the gradient part. Just add it to the list of gradients to aggregate + gradients.add( + RevGradient( + RevGradient::Flavor::Simple, + gradient.targetInst, + gradient.revGradInst, + gradient.fwdGradInst)); + break; + } + default: + SLANG_UNEXPECTED("Unexpected gradient flavor in emitAggregateDifferentialPair"); + } } - return currentValue; + // Aggregate only the differentials + diffInst = emitAggregateValue(builder, aggPairType->getValueType(), gradients); + + // Pack them back together. + return builder->emitMakeDifferentialPair(aggPrimalType, primalInst, diffInst); } - void addRevAssignmentForFwdInst(IRInst* fwdInst, IRInst* assignment) + IRInst* emitAggregateValue(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) { - if (!hasRevAssignments(fwdInst)) + // If we're dealing with the differential-pair types, we need to use a different aggregation method, since + // a differential pair is really a 'hybrid' primal-differential type. + // + if (as<IRDifferentialPairType>(aggPrimalType)) + return emitAggregateDifferentialPair(builder, aggPrimalType, gradients); + + // Process non-simple gradients into simple gradients. + // TODO: This is where we can improve efficiency later. + // For instance if we have one gradient each for var.x, var.y and var.z + // we can construct one single gradient vector out of the three vectors (i.e. float3(x_grad, y_grad, z_grad)) + // instead of creating one vector for each gradient and accumulating them + // (i.e. float3(x_grad, 0, 0) + float3(0, y_grad, 0) + float3(0, 0, z_grad)) + // The same concept can be extended for struct and array types (and for any combination of the three) + // + List<RevGradient> simpleGradients; { - assignmentsMap[fwdInst] = List<IRInst*>(); + // Start by sorting gradients based on flavor. + gradients.sort([&](const RevGradient& a, const RevGradient& b) -> bool { return a.flavor < b.flavor; }); + + Index ii = 0; + while (ii < gradients.getCount()) + { + List<RevGradient> gradientsOfFlavor; + + RevGradient::Flavor currentFlavor = (gradients.getCount() > 0) ? gradients[ii].flavor : RevGradient::Flavor::Simple; + + // Pull all the gradients matching the flavor of the top-most gradeint into a temporary list. + for (; ii < gradients.getCount(); ii++) + { + if (gradients[ii].flavor == currentFlavor) + { + gradientsOfFlavor.add(gradients[ii]); + } + else + { + break; + } + } + + // Turn the set into a simple gradient. + auto simpleGradient = materializeGradientSet(builder, aggPrimalType, gradientsOfFlavor); + SLANG_ASSERT(simpleGradient.flavor == RevGradient::Flavor::Simple); + + simpleGradients.add(simpleGradient); + } } - assignmentsMap[fwdInst].GetValue().add(assignment); + if (simpleGradients.getCount() == 0) + { + // If there are no gradients to add up, check the type and emit a 0/null value. + auto aggDiffType = (aggPrimalType) ? diffTypeContext.getDifferentialForType(builder, aggPrimalType) : nullptr; + if (aggDiffType != nullptr) + { + // If type is non-null/non-void, call T.dzero() to produce a 0 gradient. + return emitDZeroOfDiffInstType(builder, aggPrimalType); + } + else + { + // Otherwise, gradients may not be applicable for this inst. return N/A + return nullptr; + } + } + else + { + return materializeSimpleGradients(builder, aggPrimalType, simpleGradients).revGradInst; + } + } + + IRType* tryGetPrimalTypeFromDiffInst(IRInst* diffInst) + { + // Look for differential inst decoration. + if (auto diffInstDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>()) + { + return diffInstDecoration->getPrimalType(); + } + else + { + return nullptr; + } + } + + IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType) + { + auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType); + + // Should exist. + SLANG_ASSERT(zeroMethod); + + return builder->emitCallInst( + (IRType*)diffTypeContext.getDifferentialForType(builder, primalType), + zeroMethod, + List<IRInst*>()); + } + + IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2) + { + auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType); + + // Should exist. + SLANG_ASSERT(addMethod); + + return builder->emitCallInst( + (IRType*)diffTypeContext.getDifferentialForType(builder, primalType), + addMethod, + List<IRInst*>(op1, op2)); } - List<IRInst*> getRevAssignments(IRInst* fwdInst) + void addRevGradientForFwdInst(IRInst* fwdInst, RevGradient assignment) { - return assignmentsMap[fwdInst]; + if (!hasRevGradients(fwdInst)) + { + gradientsMap[fwdInst] = List<RevGradient>(); + } + + gradientsMap[fwdInst].GetValue().add(assignment); } - List<IRInst*> popRevAssignments(IRInst* fwdInst) + List<RevGradient> getRevGradients(IRInst* fwdInst) { - List<IRInst*> val = assignmentsMap[fwdInst].GetValue(); - assignmentsMap.Remove(fwdInst); + return gradientsMap[fwdInst]; + } + + List<RevGradient> popRevGradients(IRInst* fwdInst) + { + List<RevGradient> val = gradientsMap[fwdInst].GetValue(); + gradientsMap.Remove(fwdInst); return val; } - bool hasRevAssignments(IRInst* fwdInst) + bool hasRevGradients(IRInst* fwdInst) { - return assignmentsMap.ContainsKey(fwdInst); + return gradientsMap.ContainsKey(fwdInst); } + + AutoDiffSharedContext* autodiffContext; + + DifferentiableTypeConformanceContext diffTypeContext; + + DifferentialPairTypeBuilder pairBuilder; + + Dictionary<IRInst*, List<RevGradient>> gradientsMap; + + Dictionary<IRInst*, IRInst*>* primalsMap; + + List<IRInst*> usedPtrs; }; diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 344a930f2..79dec365c 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -83,15 +83,11 @@ struct DiffUnzipPass if (isDifferentialInst(child) || as<IRTerminatorInst>(child)) { - auto newInst = cloneInst(&cloneEnv, &diffBuilder, child); - child->replaceUsesWith(newInst); - child->removeAndDeallocate(); + child->insertAtEnd(diffBlock); } else { - auto newInst = cloneInst(&cloneEnv, &primalBuilder, child); - child->replaceUsesWith(newInst); - child->removeAndDeallocate(); + child->insertAtEnd(primalBlock); } child = nextChild; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 5f9ee37fa..5784f60cb 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -735,7 +735,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Used by the auto-diff pass to mark insts that compute /// a differential value. - INST(DifferentialInstDecoration, diffInstDecoration, 0, 0) + INST(DifferentialInstDecoration, diffInstDecoration, 1, 0) /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index a8bc04701..1ef0fa4f8 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -603,7 +603,11 @@ struct IRDifferentialInstDecoration : IRDecoration { kOp = kIROp_DifferentialInstDecoration }; + + IRUse primalType; IR_LEAF_ISA(DifferentialInstDecoration) + + IRType* getPrimalType() { return as<IRType>(getOperand(0)); } }; struct IRBackwardDifferentiableDecoration : IRDecoration @@ -3370,7 +3374,12 @@ public: void markInstAsDifferential(IRInst* value) { - addDecoration(value, kIROp_DifferentialInstDecoration); + addDecoration(value, kIROp_DifferentialInstDecoration, nullptr); + } + + void markInstAsDifferential(IRInst* value, IRType* primalType) + { + addDecoration(value, kIROp_DifferentialInstDecoration, primalType); } void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) |
