diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.h | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index c57dc300f..771a3977e 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -372,10 +372,32 @@ struct DiffUnzipPass } else { - // For non differentiable arguments, we can simply pass the argument as is - // if this isn't a `out` parameter, in which case it is removed from propagate call. - if (!as<IROutType>(arg->getDataType())) + if (auto inOutType = as<IRInOutType>(resolvedPrimalFuncType->getParamType(ii))) + { + // For 'inout' parameter we need to create a temp var to hold the value + // before the primal call. This logic is similar to the 'inout' case for differentiable params + // only we don't need to deal with pair types. + // + auto tempPrimalVar = primalBuilder->emitVar(as<IRPtrTypeBase>(arg->getDataType())->getValueType()); + + auto storeUse = findUniqueStoredVal(cast<IRVar>(arg)); + auto storeInst = cast<IRStore>(storeUse->getUser()); + auto storedVal = storeInst->getVal(); + + primalBuilder->emitStore(tempPrimalVar, storedVal); + + diffArgs.add(tempPrimalVar); + } + else + { + // For pure 'in' type. Simply re-use the original argument inst. + // + // For 'out' type parameters, it doesn't really matter what we pass in here, since + // the tranposition logic will discard the argument anyway (we'll pass in the old arg, + // just to keep the number of arguments consistent) + // diffArgs.add(arg); + } } } |
