From 3406f27d90a248194991b46d3f5fd89a1fd38b11 Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Fri, 21 Apr 2023 10:41:24 -0400 Subject: AD: Various fixes around dynamic dispatch (#2820) * Add a test for the new diff material system * Various fixes for AD - inout primal context params converted to out params, - added attributed types to list of stored types - used differentiated primal func type instead of type of differentiated func to avoid tangling with user-code differential types. --------- Co-authored-by: Lifan Wu --- source/slang/slang-ir-autodiff-fwd.cpp | 4 ++-- source/slang/slang-ir-autodiff-unzip.cpp | 4 ++-- source/slang/slang-ir-autodiff.cpp | 1 + 3 files changed, 5 insertions(+), 4 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 5cf3c1509..6025e1ccd 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -628,7 +628,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig while (auto attrType = as(origType)) origType = attrType->getBaseType(); } - if (auto pairType = tryGetDiffPairType(&argBuilder, origType)) + if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) { auto pairPtrType = as(pairType); auto pairValType = as( @@ -637,7 +637,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (auto ptrParamType = as(diffParamType)) { // Create temp var to pass in/out arguments. - auto srcVar = argBuilder.emitVar(ptrParamType->getValueType()); + auto srcVar = argBuilder.emitVar(pairValType); argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType()); auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 53f0cbba2..c3ce32540 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -85,7 +85,7 @@ struct ExtractPrimalFuncContext List paramTypes; for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++) paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i))); - paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); + paramTypes.add(builder.getOutType((IRType*)outIntermediateType)); auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType()); auto newFuncType = builder.getFuncType(paramTypes, resultType); return newFuncType; @@ -183,7 +183,7 @@ struct ExtractPrimalFuncContext builder.setInsertInto(paramBlock); auto oldIntermediateParam = func->getLastParam(); auto outIntermediary = - builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + builder.emitParam(builder.getOutType((IRType*)intermediateType)); oldIntermediateParam->transferDecorationsTo(outIntermediary); primalParams.Add(outIntermediary); oldIntermediateParam->replaceUsesWith(outIntermediary); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 3257ee102..9a7a42619 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -896,6 +896,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_DifferentialPairType: case kIROp_DifferentialPairUserCodeType: case kIROp_InterfaceType: + case kIROp_AssociatedType: case kIROp_AnyValueType: case kIROp_ClassType: case kIROp_FloatType: -- cgit v1.2.3