diff options
| author | Yong He <yonghe@outlook.com> | 2024-08-28 09:23:08 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-28 09:23:08 -0700 |
| commit | 638e5fb000d4e242a91e8b653da4a72daec0efda (patch) | |
| tree | cfcd15c1fc6bdee624eb33abac3268241b086dec /source/slang/slang-ir-lower-tuple-types.cpp | |
| parent | 16595a8379e9dbfa1845fd72f3531ff3372da3ef (diff) | |
Make tuple types work in autodiff. (#4923)
Diffstat (limited to 'source/slang/slang-ir-lower-tuple-types.cpp')
| -rw-r--r-- | source/slang/slang-ir-lower-tuple-types.cpp | 68 |
1 files changed, 68 insertions, 0 deletions
diff --git a/source/slang/slang-ir-lower-tuple-types.cpp b/source/slang/slang-ir-lower-tuple-types.cpp index 6177cfec2..91d6bfc29 100644 --- a/source/slang/slang-ir-lower-tuple-types.cpp +++ b/source/slang/slang-ir-lower-tuple-types.cpp @@ -262,6 +262,71 @@ namespace Slang inst->removeAndDeallocate(); } + void processUpdateElement(IRUpdateElement* inst) + { + // For UpdateElement insts, we need to figure out all the intermediate types on the access chain, + // and if any of them are lowered tuples, we need to replace the access key with the new struct + // key for the lowered tuple struct. + // + ShortList<IRInst*> newAccessChain; + bool accessChainChanged = false; + auto baseType = inst->getOldValue()->getDataType(); + IRBuilder builder(inst); + builder.setInsertBefore(inst); + + for (UInt i = 0; i < inst->getAccessKeyCount(); i++) + { + auto key = inst->getAccessKey(i); + if (auto structKey = as<IRStructKey>(key)) + { + if (auto structType = as<IRStructType>(baseType)) + { + auto field = findStructField(structType, structKey); + baseType = field->getFieldType(); + newAccessChain.add(structKey); + } + else + { + // If we see anything not supported, just bail out. + return; + } + } + else if (auto arrayType = as<IRArrayTypeBase>(baseType)) + { + baseType = arrayType->getElementType(); + newAccessChain.add(key); + } + else if (auto loweredTupleInfo = getLoweredTupleType(&builder, baseType)) + { + auto fieldIndex = getIntVal(key); + if (fieldIndex >= 0 && (Index)fieldIndex < loweredTupleInfo->fields.getCount()) + { + auto field = loweredTupleInfo->fields[fieldIndex]; + baseType = field->getFieldType(); + newAccessChain.add(field->getKey()); + accessChainChanged = true; + } + else + { + // If we see anything not supported, just bail out. + break; + } + } + else + { + // If we see anything not supported, just bail out. + break; + } + } + + if (accessChainChanged) + { + auto newInst = builder.emitUpdateElement(inst->getOldValue(), newAccessChain.getArrayView().arrayView, inst->getElementValue()); + inst->replaceUsesWith(newInst); + inst->removeAndDeallocate(); + } + } + void processInst(IRInst* inst) { switch (inst->getOp()) @@ -291,6 +356,9 @@ namespace Slang case kIROp_IndexedFieldKey: processIndexedFieldKey((IRIndexedFieldKey*)inst); break; + case kIROp_UpdateElement: + processUpdateElement((IRUpdateElement*)inst); + break; default: break; } |
