diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 |
3 files changed, 5 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 5cf3c1509..6025e1ccd 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -628,7 +628,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig while (auto attrType = as<IRAttributedType>(origType)) origType = attrType->getBaseType(); } - if (auto pairType = tryGetDiffPairType(&argBuilder, origType)) + if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) { auto pairPtrType = as<IRPtrTypeBase>(pairType); auto pairValType = as<IRDifferentialPairType>( @@ -637,7 +637,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType)) { // Create temp var to pass in/out arguments. - auto srcVar = argBuilder.emitVar(ptrParamType->getValueType()); + auto srcVar = argBuilder.emitVar(pairValType); argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType()); auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 53f0cbba2..c3ce32540 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -85,7 +85,7 @@ struct ExtractPrimalFuncContext List<IRType*> paramTypes; for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++) paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i))); - paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); + paramTypes.add(builder.getOutType((IRType*)outIntermediateType)); auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType()); auto newFuncType = builder.getFuncType(paramTypes, resultType); return newFuncType; @@ -183,7 +183,7 @@ struct ExtractPrimalFuncContext builder.setInsertInto(paramBlock); auto oldIntermediateParam = func->getLastParam(); auto outIntermediary = - builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + builder.emitParam(builder.getOutType((IRType*)intermediateType)); oldIntermediateParam->transferDecorationsTo(outIntermediary); primalParams.Add(outIntermediary); oldIntermediateParam->replaceUsesWith(outIntermediary); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 3257ee102..9a7a42619 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -896,6 +896,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_DifferentialPairType: case kIROp_DifferentialPairUserCodeType: case kIROp_InterfaceType: + case kIROp_AssociatedType: case kIROp_AnyValueType: case kIROp_ClassType: case kIROp_FloatType: |
