From 81451fb48a0dbb60cd1d9c806c4cf25085ee5e2a Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 18 May 2023 19:56:44 -0700 Subject: Add transpose logic for scalar swizzle (#2888) --- source/slang/slang-ir-autodiff-transpose.h | 44 ++++++++++++++++++++++-------- 1 file changed, 32 insertions(+), 12 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index c479ea6d1..80934be49 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2289,18 +2289,27 @@ struct DiffTransposePass auto baseType = fwdSwizzleInst->getBase()->getDataType(); // Assume for now that this is a vector type. - SLANG_ASSERT(as(baseType)); - - IRInst* elementCountInst = as(baseType)->getElementCount(); - IRType* elementType = as(baseType)->getElementType(); - - // Must be a concrete integer (auto-diff must always occur after specialization) - // For generic code, we would need to generate a for loop. - // - SLANG_ASSERT(as(elementCountInst)); - - auto elementCount = as(elementCountInst)->getValue(); + IRIntegerValue elementCount = 0; + IRType* elementType = nullptr; + bool isVectorType = false; + if (auto vectorType = as(baseType)) + { + IRInst* elementCountInst = as(baseType)->getElementCount(); + elementType = as(baseType)->getElementType(); + SLANG_ASSERT(as(elementCountInst)); + elementCount = as(elementCountInst)->getValue(); + isVectorType = true; + } + else if (auto basicType = as(baseType)) + { + elementType = basicType; + elementCount = 1; + } + else + { + SLANG_UNREACHABLE("unknown operand type of swizzle."); + } // Make a list of 0s List constructArgs; auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType); @@ -2330,7 +2339,18 @@ struct DiffTransposePass } else { - auto gradAtIndex = builder->emitElementExtract(elementType, gradient.revGradInst, builder->getIntValue(builder->getIntType(), sourceIndex)); + IRInst* gradAtIndex = nullptr; + if (isVectorType) + { + gradAtIndex = builder->emitElementExtract( + elementType, + gradient.revGradInst, + builder->getIntValue(builder->getIntType(), sourceIndex)); + } + else + { + gradAtIndex = gradient.revGradInst; + } constructArgs[(Index)targetIndex] = gradAtIndex; } } -- cgit v1.2.3