summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
-rw-r--r-- source/slang/slang-ir-autodiff-transpose.h 133
1 files changed, 72 insertions, 61 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index 80934be49..a0f888931 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -2280,49 +2280,63 @@ struct DiffTransposePass
{
List<RevGradient> simpleGradients;
- for (auto gradient : gradients)
+ SLANG_ASSERT(gradients.getCount() > 0);
+
+ auto firstGradient = gradients[0];
+ auto firstFwdSwizzleInst = as<IRSwizzle>(firstGradient.fwdGradInst);
+ SLANG_ASSERT(firstFwdSwizzleInst);
+
+ auto baseType = firstFwdSwizzleInst->getBase()->getDataType();
+
+ IRIntegerValue elementCount = 0;
+ IRType* elementType = nullptr;
+ IRType* primalElementType = nullptr;
+ bool isVectorType = false;
+
+ if (auto vectorType = as<IRVectorType>(baseType))
+ {
+ IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount();
+ elementType = as<IRVectorType>(baseType)->getElementType();
+ primalElementType = as<IRVectorType>(aggPrimalType)->getElementType();
+ SLANG_ASSERT(as<IRIntLit>(elementCountInst));
+ elementCount = as<IRIntLit>(elementCountInst)->getValue();
+ isVectorType = true;
+ }
+ else if (auto basicType = as<IRBasicType>(baseType))
+ {
+ elementType = basicType;
+ primalElementType = aggPrimalType;
+ elementCount = 1;
+ }
+ else
{
- // Peek at the fwd-mode swizzle inst to see what type we need to materialize.
- IRSwizzle* fwdSwizzleInst = as<IRSwizzle>(gradient.fwdGradInst);
- SLANG_ASSERT(fwdSwizzleInst);
+ SLANG_UNREACHABLE("unknown operand type of swizzle.");
+ }
- auto baseType = fwdSwizzleInst->getBase()->getDataType();
+ IRInst* targetInst = firstGradient.targetInst;
- // Assume for now that this is a vector type.
- IRIntegerValue elementCount = 0;
- IRType* elementType = nullptr;
- bool isVectorType = false;
+ // Make a list of zeros of the base type.
+ auto zeroElement = emitDZeroOfDiffInstType(builder, primalElementType);
- if (auto vectorType = as<IRVectorType>(baseType))
- {
- IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount();
- elementType = as<IRVectorType>(baseType)->getElementType();
- SLANG_ASSERT(as<IRIntLit>(elementCountInst));
- elementCount = as<IRIntLit>(elementCountInst)->getValue();
- isVectorType = true;
- }
- else if (auto basicType = as<IRBasicType>(baseType))
- {
- elementType = basicType;
- elementCount = 1;
- }
+ List<IRInst*> elementGrads;
+ for (Index i = 0; i < elementCount; ++i)
+ elementGrads.add(zeroElement);
+
+ auto accGrad = [&](UIndex i, IRInst* grad)
+ {
+ if (elementGrads[i] == zeroElement)
+ elementGrads[i] = grad;
else
- {
- SLANG_UNREACHABLE("unknown operand type of swizzle.");
- }
- // Make a list of 0s
- List<IRInst*> constructArgs;
- auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType);
+ elementGrads[i] = emitDAddOfDiffInstType(builder, primalElementType, elementGrads[i], grad);
+ };
- // Must exist.
- SLANG_ASSERT(zeroMethod);
+ for (auto gradient : gradients)
+ {
+ SLANG_ASSERT(gradient.targetInst == targetInst);
- auto zeroValueInst = builder->emitCallInst(elementType, zeroMethod, List<IRInst*>());
-
- for (Index ii = 0; ii < ((Index)elementCount); ii++)
- {
- constructArgs.add(zeroValueInst);
- }
+ auto fwdSwizzleInst = as<IRSwizzle>(gradient.fwdGradInst);
+ SLANG_ASSERT(as<IRSwizzle>(gradient.fwdGradInst));
+ SLANG_ASSERT(as<IRSwizzle>(gradient.fwdGradInst)->getBase() == firstFwdSwizzleInst->getBase());
// Replace swizzled elements with their gradients.
for (Index ii = 0; ii < ((Index)fwdSwizzleInst->getElementCount()); ii++)
@@ -2332,37 +2346,34 @@ struct DiffTransposePass
SLANG_ASSERT(as<IRIntLit>(targetIndexInst));
auto targetIndex = as<IRIntLit>(targetIndexInst)->getValue();
- // Special-case for when the swizzled output is a single element.
+ // Case 1: Swizzled output is a single element.
if (fwdSwizzleInst->getElementCount() == 1)
- {
- constructArgs[(Index)targetIndex] = gradient.revGradInst;
- }
- else
- {
- IRInst* gradAtIndex = nullptr;
- if (isVectorType)
- {
- gradAtIndex = builder->emitElementExtract(
+ accGrad((UIndex)targetIndex, gradient.revGradInst);
+ // Case 2: Swizzled output is a vector, so we need to extract the element.
+ else if (isVectorType)
+ accGrad((UIndex)targetIndex,
+ builder->emitElementExtract(
elementType,
gradient.revGradInst,
- builder->getIntValue(builder->getIntType(), sourceIndex));
- }
- else
- {
- gradAtIndex = gradient.revGradInst;
- }
- constructArgs[(Index)targetIndex] = gradAtIndex;
- }
+ builder->getIntValue(
+ builder->getIntType(),
+ sourceIndex)));
+ // Case 3: Swizzled input is a scalar.
+ else
+ accGrad((UIndex)targetIndex, gradient.revGradInst);
}
-
- simpleGradients.add(
- RevGradient(
- gradient.targetInst,
- builder->emitMakeVector(baseType, (UInt)elementCount, constructArgs.getBuffer()),
- gradient.fwdGradInst));
}
- return materializeSimpleGradients(builder, aggPrimalType, simpleGradients);
+ if (isVectorType)
+ return RevGradient(
+ targetInst,
+ builder->emitMakeVector(baseType, (UInt)elementCount, elementGrads.getBuffer()),
+ nullptr);
+ else
+ return RevGradient(
+ targetInst,
+ elementGrads[0],
+ nullptr);
}
RevGradient materializeDifferentialPairUserCodeGetElementGradients(IRBuilder* builder, IRType* aggPrimalType, List<RevGradient> gradients)