summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp3
-rw-r--r--source/slang/slang-ir-insts.h1
-rw-r--r--source/slang/slang-ir.cpp14
3 files changed, 15 insertions, 3 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 2be394537..5cf3c1509 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -663,7 +663,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
// Read back new value.
auto newVal = afterBuilder.emitLoad(srcVar);
afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType());
- auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(newVal);
+ auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal);
afterBuilder.emitStore(primalArg, newPrimalVal);
if (diffArg)
@@ -1798,6 +1798,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_CastFloatToInt:
case kIROp_DetachDerivative:
case kIROp_GetSequentialID:
+ case kIROp_GetStringHash:
return trascribeNonDiffInst(builder, origInst);
// A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value,
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index fe658566c..356ccf4d6 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3090,6 +3090,7 @@ public:
IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue);
IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair);
+ IRInst* emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair);
IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair);
IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair);
IRInst* emitMakeVector(
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index dd3034da1..97109274f 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -3877,7 +3877,7 @@ namespace Slang
IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair)
{
- SLANG_ASSERT(as<IRDifferentialPairType>(diffPair->getDataType()));
+ SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
return emitIntrinsicInst(
diffType,
kIROp_DifferentialPairGetDifferential,
@@ -3887,7 +3887,7 @@ namespace Slang
IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair)
{
- auto valueType = cast<IRDifferentialPairType>(diffPair->getDataType())->getValueType();
+ auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType();
return emitIntrinsicInst(
valueType,
kIROp_DifferentialPairGetPrimal,
@@ -3895,6 +3895,15 @@ namespace Slang
&diffPair);
}
+ IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair)
+ {
+ return emitIntrinsicInst(
+ primalType,
+ kIROp_DifferentialPairGetPrimal,
+ 1,
+ &diffPair);
+ }
+
IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair)
{
SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType()));
@@ -7256,6 +7265,7 @@ namespace Slang
case kIROp_TorchGetCudaStream:
case kIROp_MakeTensorView:
case kIROp_TorchTensorGetView:
+ case kIROp_GetStringHash:
return false;
case kIROp_ForwardDifferentiate: