diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-04-19 12:49:14 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-19 09:49:14 -0700 |
| commit | 520a3c064c42e8cd50ef4fde21539870d5b19cb7 (patch) | |
| tree | 6af62f3ec9273d3e1c0ddd7350f7ae20281fbda7 /source/slang/slang-ir.cpp | |
| parent | 181fd1f3c9c4b047c1947096e7b3f8e5bc2314c3 (diff) | |
Fixed issue with function signatures in higher-order AD (#2814)
Also added GetStringHash to non-differentiable insts
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
| -rw-r--r-- | source/slang/slang-ir.cpp | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index dd3034da1..97109274f 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3877,7 +3877,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) { - SLANG_ASSERT(as<IRDifferentialPairType>(diffPair->getDataType())); + SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); return emitIntrinsicInst( diffType, kIROp_DifferentialPairGetDifferential, @@ -3887,7 +3887,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) { - auto valueType = cast<IRDifferentialPairType>(diffPair->getDataType())->getValueType(); + auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( valueType, kIROp_DifferentialPairGetPrimal, @@ -3895,6 +3895,15 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair) + { + return emitIntrinsicInst( + primalType, + kIROp_DifferentialPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); @@ -7256,6 +7265,7 @@ namespace Slang case kIROp_TorchGetCudaStream: case kIROp_MakeTensorView: case kIROp_TorchTensorGetView: + case kIROp_GetStringHash: return false; case kIROp_ForwardDifferentiate: |
