summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-addr-inst-elimination.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2023-02-01 14:18:57 -0800
committerGitHub <noreply@github.com>2023-02-01 14:18:57 -0800
commitbbd1e1786401bb88c34802b987d4da72e2364503 (patch)
tree99a4be95ae517fd710fc032a1debdac917dd3ac2 /source/slang/slang-ir-addr-inst-elimination.cpp
parentc5895fb0b82fd14fbe45b58d5fc7f75d67625d15 (diff)
Support `out` parameters in backward differentiation. (#2619)
* Support `out` parameters in backward differentiation. * Fixes. * Fix cleanup. --------- Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir-addr-inst-elimination.cpp')
-rw-r--r--source/slang/slang-ir-addr-inst-elimination.cpp31
1 files changed, 27 insertions, 4 deletions
diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp
index a5e0e0a4e..a451e24a5 100644
--- a/source/slang/slang-ir-addr-inst-elimination.cpp
+++ b/source/slang/slang-ir-addr-inst-elimination.cpp
@@ -54,11 +54,18 @@ struct AddressInstEliminationContext
}
endLoop:;
auto lastAddr = accessChain.getLast();
- auto lastVal = builder.emitLoad(lastAddr);
accessChain.removeLast();
accessChain.reverse();
- auto update = builder.emitUpdateElement(lastVal, accessChain, val);
- builder.emitStore(lastAddr, update);
+ if (accessChain.getCount())
+ {
+ auto lastVal = builder.emitLoad(lastAddr);
+ auto update = builder.emitUpdateElement(lastVal, accessChain, val);
+ builder.emitStore(lastAddr, update);
+ }
+ else
+ {
+ builder.emitStore(lastAddr, val);
+ }
}
void transformLoadAddr(IRUse* use)
@@ -92,7 +99,22 @@ struct AddressInstEliminationContext
IRBuilder builder(sharedBuilder);
builder.setInsertBefore(call);
auto tempVar = builder.emitVar(cast<IRPtrTypeBase>(addr->getFullType())->getValueType());
- builder.emitStore(tempVar, getValue(builder, addr));
+ auto callee = getResolvedInstForDecorations(call->getCallee());
+ auto funcType = as<IRFuncType>(callee->getFullType());
+ SLANG_RELEASE_ASSERT(funcType);
+ UInt paramIndex = (UInt)(use - call->getOperands() - 1);
+ SLANG_RELEASE_ASSERT(call->getArg(paramIndex) == addr);
+ if (!as<IROutType>(funcType->getParamType(paramIndex)))
+ {
+ builder.emitStore(tempVar, getValue(builder, addr));
+ }
+ else
+ {
+ builder.emitStore(
+ tempVar,
+ builder.emitDefaultConstruct(
+ as<IRPtrTypeBase>(tempVar->getDataType())->getValueType()));
+ }
builder.setInsertAfter(call);
storeValue(builder, addr, builder.emitLoad(tempVar));
use->set(tempVar);
@@ -170,4 +192,5 @@ SlangResult eliminateAddressInsts(
AddressInstEliminationContext ctx;
return ctx.eliminateAddressInstsImpl(sharedBuilder, policy, func, sink);
}
+
} // namespace Slang