diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-04-21 10:41:24 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-04-21 10:41:24 -0400 |
| commit | 3406f27d90a248194991b46d3f5fd89a1fd38b11 (patch) | |
| tree | ef01e52d656de622be65a9d98c8554af9b91b766 /source | |
| parent | cc948557ab305ae0d0c3fe7c8915ab32000bd09e (diff) | |
AD: Various fixes around dynamic dispatch (#2820)
* Add a test for the new diff material system
* Various fixes for AD
- inout primal context params converted to out params,
- added attributed types to list of stored types
- used differentiated primal func type instead of type of differentiated func to avoid tangling with user-code differential types.
---------
Co-authored-by: Lifan Wu <lifanw@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-unzip.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 1 |
3 files changed, 5 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 5cf3c1509..6025e1ccd 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -628,7 +628,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig while (auto attrType = as<IRAttributedType>(origType)) origType = attrType->getBaseType(); } - if (auto pairType = tryGetDiffPairType(&argBuilder, origType)) + if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) { auto pairPtrType = as<IRPtrTypeBase>(pairType); auto pairValType = as<IRDifferentialPairType>( @@ -637,7 +637,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType)) { // Create temp var to pass in/out arguments. - auto srcVar = argBuilder.emitVar(ptrParamType->getValueType()); + auto srcVar = argBuilder.emitVar(pairValType); argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType()); auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 53f0cbba2..c3ce32540 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -85,7 +85,7 @@ struct ExtractPrimalFuncContext List<IRType*> paramTypes; for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++) paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i))); - paramTypes.add(builder.getInOutType((IRType*)outIntermediateType)); + paramTypes.add(builder.getOutType((IRType*)outIntermediateType)); auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType()); auto newFuncType = builder.getFuncType(paramTypes, resultType); return newFuncType; @@ -183,7 +183,7 @@ struct ExtractPrimalFuncContext builder.setInsertInto(paramBlock); auto oldIntermediateParam = func->getLastParam(); auto outIntermediary = - builder.emitParam(builder.getInOutType((IRType*)intermediateType)); + builder.emitParam(builder.getOutType((IRType*)intermediateType)); oldIntermediateParam->transferDecorationsTo(outIntermediary); primalParams.Add(outIntermediary); oldIntermediateParam->replaceUsesWith(outIntermediary); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 3257ee102..9a7a42619 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -896,6 +896,7 @@ bool canTypeBeStored(IRInst* type) case kIROp_DifferentialPairType: case kIROp_DifferentialPairUserCodeType: case kIROp_InterfaceType: + case kIROp_AssociatedType: case kIROp_AnyValueType: case kIROp_ClassType: case kIROp_FloatType: |
