summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-ir.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-04-19 12:49:14 -0400
committerGitHub <noreply@github.com>2023-04-19 09:49:14 -0700
commit520a3c064c42e8cd50ef4fde21539870d5b19cb7 (patch)
tree6af62f3ec9273d3e1c0ddd7350f7ae20281fbda7 /source/slang/slang-ir.cpp
parent181fd1f3c9c4b047c1947096e7b3f8e5bc2314c3 (diff)
Fixed issue with function signatures in higher-order AD (#2814)
Also added GetStringHash to non-differentiable insts Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir.cpp')
-rw-r--r--source/slang/slang-ir.cpp14
1 files changed, 12 insertions, 2 deletions
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index dd3034da1..97109274f 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3877,7 +3877,7 @@ namespace Slang
IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
{
- SLANG_ASSERT(as<IRDifferentialPairType>(diffPair->getDataType()));
+ SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
return emitIntrinsicInst(
diffType,
kIROp_DifferentialPairGetDifferential,
@@ -3887,7 +3887,7 @@ namespace Slang
IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
{
- auto valueType = cast<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
+ auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
return emitIntrinsicInst(
valueType,
kIROp_DifferentialPairGetPrimal,
@@ -3895,6 +3895,15 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ primalType,
+ kIROp_DifferentialPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
@@ -7256,6 +7265,7 @@ namespace Slang
case kIROp_TorchGetCudaStream:
case kIROp_MakeTensorView:
case kIROp_TorchTensorGetView:
+ case kIROp_GetStringHash:
return false;
case kIROp_ForwardDifferentiate: