summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-05-18 19:56:44 -0700
committerGitHub <noreply@github.com>2023-05-18 19:56:44 -0700
commit81451fb48a0dbb60cd1d9c806c4cf25085ee5e2a (patch)
tree3d0b98d8c8efe3f81091aaa561f8feb0bf352a91
parentd8ab2e893f3558b0b0ada5581d4b9e0fe4515d82 (diff)
Add transpose logic for scalar swizzle (#2888)
-rw-r--r--source/slang/slang-ir-autodiff-transpose.h44
1 files changed, 32 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h
index c479ea6d1..80934be49 100644
--- a/source/slang/slang-ir-autodiff-transpose.h
+++ b/source/slang/slang-ir-autodiff-transpose.h
@@ -2289,18 +2289,27 @@ struct DiffTransposePass
auto baseType = fwdSwizzleInst->getBase()->getDataType();
// Assume for now that this is a vector type.
- SLANG_ASSERT(as<IRVectorType>(baseType));
-
- IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount();
- IRType* elementType = as<IRVectorType>(baseType)->getElementType();
-
- // Must be a concrete integer (auto-diff must always occur after specialization)
- // For generic code, we would need to generate a for loop.
- //
- SLANG_ASSERT(as<IRIntLit>(elementCountInst));
-
- auto elementCount = as<IRIntLit>(elementCountInst)->getValue();
+ IRIntegerValue elementCount = 0;
+ IRType* elementType = nullptr;
+ bool isVectorType = false;
+ if (auto vectorType = as<IRVectorType>(baseType))
+ {
+ IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount();
+ elementType = as<IRVectorType>(baseType)->getElementType();
+ SLANG_ASSERT(as<IRIntLit>(elementCountInst));
+ elementCount = as<IRIntLit>(elementCountInst)->getValue();
+ isVectorType = true;
+ }
+ else if (auto basicType = as<IRBasicType>(baseType))
+ {
+ elementType = basicType;
+ elementCount = 1;
+ }
+ else
+ {
+ SLANG_UNREACHABLE("unknown operand type of swizzle.");
+ }
// Make a list of 0s
List<IRInst*> constructArgs;
auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType);
@@ -2330,7 +2339,18 @@ struct DiffTransposePass
}
else
{
- auto gradAtIndex = builder->emitElementExtract(elementType, gradient.revGradInst, builder->getIntValue(builder->getIntType(), sourceIndex));
+ IRInst* gradAtIndex = nullptr;
+ if (isVectorType)
+ {
+ gradAtIndex = builder->emitElementExtract(
+ elementType,
+ gradient.revGradInst,
+ builder->getIntValue(builder->getIntType(), sourceIndex));
+ }
+ else
+ {
+ gradAtIndex = gradient.revGradInst;
+ }
constructArgs[(Index)targetIndex] = gradAtIndex;
}
}