diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-06-21 19:18:30 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-06-21 19:18:30 -0400 |
| commit | ac541d45fafde340b141172cf76d003ff70d471e (patch) | |
| tree | e41c40b6fee55e2287fe12a4653cba25a97bcbec | |
| parent | 79b0a2a555d17bb2fd3f391be83bab4809288075 (diff) | |
Avoid materializing multiple swizzle gradients (#2923)
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 133 |
1 files changed, 72 insertions, 61 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 80934be49..a0f888931 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2280,49 +2280,63 @@ struct DiffTransposePass { List<RevGradient> simpleGradients; - for (auto gradient : gradients) + SLANG_ASSERT(gradients.getCount() > 0); + + auto firstGradient = gradients[0]; + auto firstFwdSwizzleInst = as<IRSwizzle>(firstGradient.fwdGradInst); + SLANG_ASSERT(firstFwdSwizzleInst); + + auto baseType = firstFwdSwizzleInst->getBase()->getDataType(); + + IRIntegerValue elementCount = 0; + IRType* elementType = nullptr; + IRType* primalElementType = nullptr; + bool isVectorType = false; + + if (auto vectorType = as<IRVectorType>(baseType)) + { + IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount(); + elementType = as<IRVectorType>(baseType)->getElementType(); + primalElementType = as<IRVectorType>(aggPrimalType)->getElementType(); + SLANG_ASSERT(as<IRIntLit>(elementCountInst)); + elementCount = as<IRIntLit>(elementCountInst)->getValue(); + isVectorType = true; + } + else if (auto basicType = as<IRBasicType>(baseType)) + { + elementType = basicType; + primalElementType = aggPrimalType; + elementCount = 1; + } + else { - // Peek at the fwd-mode swizzle inst to see what type we need to materialize. - IRSwizzle* fwdSwizzleInst = as<IRSwizzle>(gradient.fwdGradInst); - SLANG_ASSERT(fwdSwizzleInst); + SLANG_UNREACHABLE("unknown operand type of swizzle."); + } - auto baseType = fwdSwizzleInst->getBase()->getDataType(); + IRInst* targetInst = firstGradient.targetInst; - // Assume for now that this is a vector type. - IRIntegerValue elementCount = 0; - IRType* elementType = nullptr; - bool isVectorType = false; + // Make a list of zeros of the base type. + auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementType); - if (auto vectorType = as<IRVectorType>(baseType)) - { - IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount(); - elementType = as<IRVectorType>(baseType)->getElementType(); - SLANG_ASSERT(as<IRIntLit>(elementCountInst)); - elementCount = as<IRIntLit>(elementCountInst)->getValue(); - isVectorType = true; - } - else if (auto basicType = as<IRBasicType>(baseType)) - { - elementType = basicType; - elementCount = 1; - } + List<IRInst*> elementGrads; + for (Index i = 0; i < elementCount; ++i) + elementGrads.add(zeroElement); + + auto accGrad = [&](UIndex i, IRInst* grad) + { + if (elementGrads[i] == zeroElement) + elementGrads[i] = grad; else - { - SLANG_UNREACHABLE("unknown operand type of swizzle."); - } - // Make a list of 0s - List<IRInst*> constructArgs; - auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType); + elementGrads[i] = emitDAddOfDiffInstType(builder, primalElementType, elementGrads[i], grad); + }; - // Must exist. - SLANG_ASSERT(zeroMethod); + for (auto gradient : gradients) + { + SLANG_ASSERT(gradient.targetInst == targetInst); - auto zeroValueInst = builder->emitCallInst(elementType, zeroMethod, List<IRInst*>()); - - for (Index ii = 0; ii < ((Index)elementCount); ii++) - { - constructArgs.add(zeroValueInst); - } + auto fwdSwizzleInst = as<IRSwizzle>(gradient.fwdGradInst); + SLANG_ASSERT(as<IRSwizzle>(gradient.fwdGradInst)); + SLANG_ASSERT(as<IRSwizzle>(gradient.fwdGradInst)->getBase() == firstFwdSwizzleInst->getBase()); // Replace swizzled elements with their gradients. for (Index ii = 0; ii < ((Index)fwdSwizzleInst->getElementCount()); ii++) @@ -2332,37 +2346,34 @@ struct DiffTransposePass SLANG_ASSERT(as<IRIntLit>(targetIndexInst)); auto targetIndex = as<IRIntLit>(targetIndexInst)->getValue(); - // Special-case for when the swizzled output is a single element. + // Case 1: Swizzled output is a single element, if (fwdSwizzleInst->getElementCount() == 1) - { - constructArgs[(Index)targetIndex] = gradient.revGradInst; - } - else - { - IRInst* gradAtIndex = nullptr; - if (isVectorType) - { - gradAtIndex = builder->emitElementExtract( + accGrad((UIndex)targetIndex, gradient.revGradInst); + // Case 2: Swizzled output is a vector, so we need to extract the element. + else if (isVectorType) + accGrad((UIndex)targetIndex, + builder->emitElementExtract( elementType, gradient.revGradInst, - builder->getIntValue(builder->getIntType(), sourceIndex)); - } - else - { - gradAtIndex = gradient.revGradInst; - } - constructArgs[(Index)targetIndex] = gradAtIndex; - } + builder->getIntValue( + builder->getIntType(), + sourceIndex))); + // Case 3: Swizzled input is a scalar. + else + accGrad((UIndex)targetIndex, gradient.revGradInst); } - - simpleGradients.add( - RevGradient( - gradient.targetInst, - builder->emitMakeVector(baseType, (UInt)elementCount, constructArgs.getBuffer()), - gradient.fwdGradInst)); } - return materializeSimpleGradients(builder, aggPrimalType, simpleGradients); + if (isVectorType) + return RevGradient( + targetInst, + builder->emitMakeVector(baseType, (UInt)elementCount, elementGrads.getBuffer()), + nullptr); + else + return RevGradient( + targetInst, + elementGrads[0], + nullptr); } RevGradient materializeDifferentialPairUserCodeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients) |
