summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2023-04-21 10:41:24 -0400
committerGitHub <noreply@github.com>2023-04-21 10:41:24 -0400
commit3406f27d90a248194991b46d3f5fd89a1fd38b11 (patch)
treeef01e52d656de622be65a9d98c8554af9b91b766 /source
parentcc948557ab305ae0d0c3fe7c8915ab32000bd09e (diff)
AD: Various fixes around dynamic dispatch (#2820)
* Add a test for the new diff material system * Various fixes for AD - inout primal context params converted to out params, - added attributed types to list of stored types - used differentiated primal func type instead of type of differentiated func to avoid tangling with user-code differential types. --------- Co-authored-by: Lifan Wu <lifanw@nvidia.com>
Diffstat (limited to 'source')
-rw-r--r--source/slang/slang-ir-autodiff-fwd.cpp4
-rw-r--r--source/slang/slang-ir-autodiff-unzip.cpp4
-rw-r--r--source/slang/slang-ir-autodiff.cpp1
3 files changed, 5 insertions, 4 deletions
diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp
index 5cf3c1509..6025e1ccd 100644
--- a/source/slang/slang-ir-autodiff-fwd.cpp
+++ b/source/slang/slang-ir-autodiff-fwd.cpp
@@ -628,7 +628,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
while (auto attrType = as<IRAttributedType>(origType))
origType = attrType->getBaseType();
}
- if (auto pairType = tryGetDiffPairType(&argBuilder, origType))
+ if (auto pairType = tryGetDiffPairType(&argBuilder, primalType))
{
auto pairPtrType = as<IRPtrTypeBase>(pairType);
auto pairValType = as<IRDifferentialPairType>(
@@ -637,7 +637,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig
if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType))
{
// Create temp var to pass in/out arguments.
- auto srcVar = argBuilder.emitVar(ptrParamType->getValueType());
+ auto srcVar = argBuilder.emitVar(pairValType);
argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType());
auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg);
diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp
index 53f0cbba2..c3ce32540 100644
--- a/source/slang/slang-ir-autodiff-unzip.cpp
+++ b/source/slang/slang-ir-autodiff-unzip.cpp
@@ -85,7 +85,7 @@ struct ExtractPrimalFuncContext
List<IRType*> paramTypes;
for (Index i = 0; i < ((Count) originalFuncType->getParamCount()) - 1; i++)
paramTypes.add((IRType*)migrationContext.cloneInst(&builder, originalFuncType->getParamType(i)));
- paramTypes.add(builder.getInOutType((IRType*)outIntermediateType));
+ paramTypes.add(builder.getOutType((IRType*)outIntermediateType));
auto resultType = (IRType*)migrationContext.cloneInst(&builder, originalFuncType->getResultType());
auto newFuncType = builder.getFuncType(paramTypes, resultType);
return newFuncType;
@@ -183,7 +183,7 @@ struct ExtractPrimalFuncContext
builder.setInsertInto(paramBlock);
auto oldIntermediateParam = func->getLastParam();
auto outIntermediary =
- builder.emitParam(builder.getInOutType((IRType*)intermediateType));
+ builder.emitParam(builder.getOutType((IRType*)intermediateType));
oldIntermediateParam->transferDecorationsTo(outIntermediary);
primalParams.Add(outIntermediary);
oldIntermediateParam->replaceUsesWith(outIntermediary);
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 3257ee102..9a7a42619 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -896,6 +896,7 @@ bool canTypeBeStored(IRInst* type)
case kIROp_DifferentialPairType:
case kIROp_DifferentialPairUserCodeType:
case kIROp_InterfaceType:
+ case kIROp_AssociatedType:
case kIROp_AnyValueType:
case kIROp_ClassType:
case kIROp_FloatType: