summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h44
1 files changed, 32 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index c479ea6d1..80934be49 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -2289,18 +2289,27 @@ struct DiffTransposePass
auto baseType = fwdSwizzleInst->getBase()->getDataType();
// Assume for now that this is a vector type.
- SLANG_ASSERT(as<IRVectorType>(baseType));
-
- IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount();
- IRType* elementType = as<IRVectorType>(baseType)->getElementType();
-
- // Must be a concrete integer (auto-diff must always occur after specialization)
- // For generic code, we would need to generate a for loop.
- //
- SLANG_ASSERT(as<IRIntLit>(elementCountInst));
-
- auto elementCount = as<IRIntLit>(elementCountInst)->getValue();
+ IRIntegerValue elementCount = 0;
+ IRType* elementType = nullptr;
+ bool isVectorType = false;
+ if (auto vectorType = as<IRVectorType>(baseType))
+ {
+ IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount();
+ elementType = as<IRVectorType>(baseType)->getElementType();
+ SLANG_ASSERT(as<IRIntLit>(elementCountInst));
+ elementCount = as<IRIntLit>(elementCountInst)->getValue();
+ isVectorType = true;
+ }
+ else if (auto basicType = as<IRBasicType>(baseType))
+ {
+ elementType = basicType;
+ elementCount = 1;
+ }
+ else
+ {
+ SLANG_UNREACHABLE("unknown operand type of swizzle.");
+ }
// Make a list of 0s
List<IRInst*> constructArgs;
auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType);
@@ -2330,7 +2339,18 @@ struct DiffTransposePass
}
else
{
- auto gradAtIndex = builder->emitElementExtract(elementType, gradient.revGradInst, builder->getIntValue(builder->getIntType(), sourceIndex));
+ IRInst* gradAtIndex = nullptr;
+ if (isVectorType)
+ {
+ gradAtIndex = builder->emitElementExtract(
+ elementType,
+ gradient.revGradInst,
+ builder->getIntValue(builder->getIntType(), sourceIndex));
+ }
+ else
+ {
+ gradAtIndex = gradient.revGradInst;
+ }
constructArgs[(Index)targetIndex] = gradAtIndex;
}
}