diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 7 |
1 files changed, 5 insertions, 2 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index bcebd2108..a3a4eb2b3 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -1841,6 +1841,7 @@ struct DiffTransposePass TranspositionResult transposeMakeVector(IRBuilder* builder, IRInst* fwdMakeVector, IRInst* revValue) { List<RevGradient> gradients; + UInt offset = 0; for (UIndex ii = 0; ii < fwdMakeVector->getOperandCount(); ii++) { auto argOperand = fwdMakeVector->getOperand(ii); @@ -1857,12 +1858,12 @@ struct DiffTransposePass gradAtIndex = builder->emitElementExtract( argOperand->getDataType(), revValue, - builder->getIntValue(builder->getIntType(), ii)); + builder->getIntValue(builder->getIntType(), offset)); } else { ShortList<UInt> componentIndices; - for (UInt index = ii; index < ii + componentCount; index++) + for (UInt index = offset; index < offset + componentCount; index++) componentIndices.add(index); gradAtIndex = builder->emitSwizzle( argOperand->getDataType(), @@ -1876,6 +1877,8 @@ struct DiffTransposePass fwdMakeVector->getOperand(ii), gradAtIndex, fwdMakeVector)); + + offset += componentCount; } // (A = float3(X, Y, Z)) -> [(dX += dA), (dY += dA), (dZ += dA)] |
