diff options
| author | Yong He <yonghe@outlook.com> | 2023-05-18 19:56:44 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-05-18 19:56:44 -0700 |
| commit | 81451fb48a0dbb60cd1d9c806c4cf25085ee5e2a (patch) | |
| tree | 3d0b98d8c8efe3f81091aaa561f8feb0bf352a91 | |
| parent | d8ab2e893f3558b0b0ada5581d4b9e0fe4515d82 (diff) | |
Add transpose logic for scalar swizzle (#2888)
| -rw-r--r-- | source/slang/slang-ir-autodiff-transpose.h | 44 |
1 files changed, 32 insertions, 12 deletions
diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index c479ea6d1..80934be49 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2289,18 +2289,27 @@ struct DiffTransposePass auto baseType = fwdSwizzleInst->getBase()->getDataType(); // Assume for now that this is a vector type. - SLANG_ASSERT(as<IRVectorType>(baseType)); - - IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount(); - IRType* elementType = as<IRVectorType>(baseType)->getElementType(); - - // Must be a concrete integer (auto-diff must always occur after specialization) - // For generic code, we would need to generate a for loop. - // - SLANG_ASSERT(as<IRIntLit>(elementCountInst)); - - auto elementCount = as<IRIntLit>(elementCountInst)->getValue(); + IRIntegerValue elementCount = 0; + IRType* elementType = nullptr; + bool isVectorType = false; + if (auto vectorType = as<IRVectorType>(baseType)) + { + IRInst* elementCountInst = as<IRVectorType>(baseType)->getElementCount(); + elementType = as<IRVectorType>(baseType)->getElementType(); + SLANG_ASSERT(as<IRIntLit>(elementCountInst)); + elementCount = as<IRIntLit>(elementCountInst)->getValue(); + isVectorType = true; + } + else if (auto basicType = as<IRBasicType>(baseType)) + { + elementType = basicType; + elementCount = 1; + } + else + { + SLANG_UNREACHABLE("unknown operand type of swizzle."); + } // Make a list of 0s List<IRInst*> constructArgs; auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, elementType); @@ -2330,7 +2339,18 @@ struct DiffTransposePass } else { - auto gradAtIndex = builder->emitElementExtract(elementType, gradient.revGradInst, builder->getIntValue(builder->getIntType(), sourceIndex)); + IRInst* gradAtIndex = nullptr; + if (isVectorType) + { + gradAtIndex = builder->emitElementExtract( + elementType, + gradient.revGradInst, + builder->getIntValue(builder->getIntType(), sourceIndex)); + } + else + { + gradAtIndex = gradient.revGradInst; + } constructArgs[(Index)targetIndex] = gradAtIndex; } } |
