diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index c479ea6d1..80934be49 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2289,18 +2289,27 @@ struct DiffTransposePass auto baseType = fwdSwizzleInst->getBase()->getDataType(); // Assume for now that this is a vector type. - SLANG_ASSERT(as<IRVectorType>(baseType)); - - IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount(); - IRType* elementType = as<IRVectorType>(baseType)->getElementType(); - - // Must be a concrete integer (auto-diff must always occur after specialization) - // For generic code, we would need to generate a for loop. - // - SLANG_ASSERT(as<IRIntLit>(elementCountInst)); - - auto elementCount = as<IRIntLit>(elementCountInst)->getValue(); + IRIntegerValue elementCount = 0; + IRType* elementType = nullptr; + bool isVectorType = false; + if (auto vectorType = as<IRVectorType>(baseType)) + { + IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount(); + elementType = as<IRVectorType>(baseType)->getElementType(); + SLANG_ASSERT(as<IRIntLit>(elementCountInst)); + elementCount = as<IRIntLit>(elementCountInst)->getValue(); + isVectorType = true; + } + else if (auto basicType = as<IRBasicType>(baseType)) + { + elementType = basicType; + elementCount = 1; + } + else + { + SLANG_UNREACHABLE("unknown operand type of swizzle."); + } // Make a list of 0s List<IRInst*> constructArgs; auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType); @@ -2330,7 +2339,18 @@ struct DiffTransposePass } else { - auto gradAtIndex = builder->emitElementExtract(elementType, gradient.revGradInst, builder->getIntValue(builder->getIntType(), sourceIndex)); + IRInst* gradAtIndex = nullptr; + if (isVectorType) + { + gradAtIndex = builder->emitElementExtract( + elementType, + gradient.revGradInst, + builder->getIntValue(builder->getIntType(), sourceIndex)); + } + else + { + gradAtIndex = gradient.revGradInst; + } constructArgs[(Index)targetIndex] = gradAtIndex; } } |
