diff options
| author | Yong He <yonghe@outlook.com> | 2023-02-03 16:44:33 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-02-03 16:44:33 -0800 |
| commit | 228e71dab7dfa18ece979f4099ec0c7d1e37e5ff (patch) | |
| tree | ff357f4aaed2dab25ae9e3665a97a7f3e6be32ef /source/slang/slang-ir.cpp | |
| parent | ee49a62083d28353812185fd0f0c04fb50ca6be0 (diff) | |
Overhaul `transposeParameterBlock` to support `inout` params. (#2621)
* Overhaul `transposeParameterBlock` to support `inout` params.
* Small bug fixes.
* Bug fix on differentiable intrinsic specialization.
* Fixes.
* Run autodiff tests on CPU.
* Clean up.
* More bug fixes.,
* Add test coverage on inout param.
* Fix language server hinting for transcribed mutable params.
---------
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
| -rw-r--r-- | source/slang/slang-ir.cpp | 50 |
1 files changed, 50 insertions, 0 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 4814726cf..558574bf6 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -4234,6 +4234,50 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitReverseGradientDiffPairRef(IRType* type, IRInst* primalVar, IRInst* diffVar) + { + auto inst = createInst<IRReverseGradientDiffPairRef>( + this, + kIROp_ReverseGradientDiffPairRef, + type, + primalVar, + diffVar); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitPrimalParamRef(IRInst* param) + { + auto type = param->getFullType(); + auto ptrType = as<IRPtrTypeBase>(type); + auto valueType = type; + if (ptrType) valueType = ptrType->getValueType(); + auto pairType = as<IRDifferentialPairType>(valueType); + IRType* finalType = pairType->getValueType(); + if (ptrType) finalType = getPtrType(ptrType->getOp(), finalType); + auto inst = createInst<IRPrimalParamRef>( + this, + kIROp_PrimalParamRef, + finalType, + param); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitDiffParamRef(IRType* type, IRInst* param) + { + auto inst = createInst<IRDiffParamRef>( + this, + kIROp_DiffParamRef, + type, + param); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitLoad( IRType* type, IRInst* ptr) @@ -6753,6 +6797,8 @@ namespace Slang // common subexpression elimination, etc. // auto call = cast<IRCall>(this); + if (call->findDecoration<IRNoSideEffectDecoration>()) + return false; return !isPureFunctionalCall(call); } break; @@ -6809,10 +6855,14 @@ namespace Slang case kIROp_MakeOptionalNone: case kIROp_OptionalHasValue: case kIROp_GetOptionalValue: + case kIROp_DifferentialPairGetPrimal: + case kIROp_DifferentialPairGetDifferential: + case kIROp_MakeDifferentialPair: case kIROp_MakeTuple: case kIROp_GetTupleElement: case kIROp_Load: // We are ignoring the possibility of loads from bad addresses, or `volatile` loads case kIROp_LoadReverseGradient: + case kIROp_ReverseGradientDiffPairRef: case kIROp_ImageSubscript: case kIROp_FieldExtract: case kIROp_FieldAddress: |
