diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-04-19 12:49:14 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-19 09:49:14 -0700 |
| commit | 520a3c064c42e8cd50ef4fde21539870d5b19cb7 (patch) | |
| tree | 6af62f3ec9273d3e1c0ddd7350f7ae20281fbda7 | |
| parent | 181fd1f3c9c4b047c1947096e7b3f8e5bc2314c3 (diff) | |
Fixed issue with function signatures in higher-order AD (#2814)
Also added GetStringHash to non-differentiable insts
Co-authored-by: Yong He <yonghe@outlook.com>
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 3 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 14 |
3 files changed, 15 insertions, 3 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 2be394537..5cf3c1509 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -663,7 +663,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig // Read back new value. auto newVal = afterBuilder.emitLoad(srcVar); afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType()); - auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(newVal); + auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal); afterBuilder.emitStore(primalArg, newPrimalVal); if (diffArg) @@ -1798,6 +1798,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_CastFloatToInt: case kIROp_DetachDerivative: case kIROp_GetSequentialID: + case kIROp_GetStringHash: return trascribeNonDiffInst(builder, origInst); // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index fe658566c..356ccf4d6 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3090,6 +3090,7 @@ public: IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair); IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair); IRInst* emitMakeVector( diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index dd3034da1..97109274f 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3877,7 +3877,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) { - SLANG_ASSERT(as<IRDifferentialPairType>(diffPair->getDataType())); + SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); return emitIntrinsicInst( diffType, kIROp_DifferentialPairGetDifferential, @@ -3887,7 +3887,7 @@ namespace Slang IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) { - auto valueType = cast<IRDifferentialPairType>(diffPair->getDataType())->getValueType(); + auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( valueType, kIROp_DifferentialPairGetPrimal, @@ -3895,6 +3895,15 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair) + { + return emitIntrinsicInst( + primalType, + kIROp_DifferentialPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); @@ -7256,6 +7265,7 @@ namespace Slang case kIROp_TorchGetCudaStream: case kIROp_MakeTensorView: case kIROp_TorchTensorGetView: + case kIROp_GetStringHash: return false; case kIROp_ForwardDifferentiate: |
