summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp4
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
3 files changed, 5 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 5cf3c1509..6025e1ccd 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -628,7 +628,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
while (auto attrType = as<IRAttributedType>(origType))
origType = attrType->getBaseType();
}
- if (auto pairType = tryGetDiffPairType(&argBuilder, origType))
+ if (auto pairType = tryGetDiffPairType(&argBuilder, primalType))
{
auto pairPtrType = as<IRPtrTypeBase>(pairType);
auto pairValType = as<IRDifferentialPairType>(
@@ -637,7 +637,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
- auto srcVar = argBuilder.emitVar(ptrParamType->getValueType());
+ auto srcVar = argBuilder.emitVar(pairValType);
argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType());
auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 53f0cbba2..c3ce32540 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -85,7 +85,7 @@ struct ExtractPrimalFuncContext
List<IRType*> paramTypes;
for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++)
paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i)));
- paramTypes.add(builder.getInOutType((IRType*)outIntermediateType));
+ paramTypes.add(builder.getOutType((IRType*)outIntermediateType));
auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType());
auto newFuncType = builder.getFuncType(paramTypes, resultType);
return newFuncType;
@@ -183,7 +183,7 @@ struct ExtractPrimalFuncContext
builder.setInsertInto(paramBlock);
auto oldIntermediateParam = func->getLastParam();
auto outIntermediary =
- builder.emitParam(builder.getInOutType((IRType*)intermediateType));
+ builder.emitParam(builder.getOutType((IRType*)intermediateType));
oldIntermediateParam->transferDecorationsTo(outIntermediary);
primalParams.Add(outIntermediary);
oldIntermediateParam->replaceUsesWith(outIntermediary);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 3257ee102..9a7a42619 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -896,6 +896,7 @@ bool canTypeBeStored(IRInst* type)
case kIROp_DifferentialPairType:
case kIROp_DifferentialPairUserCodeType:
case kIROp_InterfaceType:
+ case kIROp_AssociatedType:
case kIROp_AnyValueType:
case kIROp_ClassType:
case kIROp_FloatType: