summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-ir-autodiff.cpp')
-rw-r--r--source/slang/slang-ir-autodiff.cpp51
1 files changed, 35 insertions, 16 deletions
diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp
index 65e880868..edea3847d 100644
--- a/source/slang/slang-ir-autodiff.cpp
+++ b/source/slang/slang-ir-autodiff.cpp
@@ -44,6 +44,18 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK
return nullptr;
}
+static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
+}
+
+static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ auto witnessTable = type->getWitness();
+ return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
+}
+
bool isNoDiffType(IRType* paramType)
{
while (auto ptrType = as<IRPtrTypeBase>(paramType))
@@ -266,25 +278,13 @@ IRInst* DifferentialPairTypeBuilder::_createDiffPairType(IRType* origBaseType, I
return pairStructType;
}
-IRInst* DifferentialPairTypeBuilder::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
-{
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey);
-}
-
-IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type)
-{
- auto witnessTable = type->getWitness();
- return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey);
-}
-
IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
IRBuilder* builder, IRType* originalPairType)
{
IRInst* result = nullptr;
if (pairTypeCache.TryGetValue(originalPairType, result))
return result;
- auto pairType = as<IRDifferentialPairType>(originalPairType);
+ auto pairType = as<IRDifferentialPairTypeBase>(originalPairType);
if (!pairType)
{
result = originalPairType;
@@ -297,7 +297,7 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType(
return result;
}
- auto diffType = getDiffTypeFromPairType(builder, pairType);
+ auto diffType = _getDiffTypeFromPairType(sharedContext, builder, pairType);
if (!diffType)
return result;
result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType);
@@ -406,18 +406,28 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b
}
IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType(
- IRBuilder* builder, IRDifferentialPairType* diffPairType)
+ IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType)
{
auto witness = diffPairType->getWitness();
SLANG_RELEASE_ASSERT(witness);
return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey);
}
+IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ return _getDiffTypeFromPairType(sharedContext, builder, type);
+}
+
+IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type)
+{
+ return _getDiffTypeWitnessFromPairType(sharedContext, builder, type);
+}
+
void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary()
{
for (auto globalInst : sharedContext->moduleInst->getChildren())
{
- if (auto pairType = as<IRDifferentialPairType>(globalInst))
+ if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst))
{
differentiableWitnessDictionary.AddIfNotExists(pairType->getValueType(), pairType->getWitness());
}
@@ -505,6 +515,7 @@ void stripTempDecorations(IRInst* inst)
case kIROp_AutoDiffOriginalValueDecoration:
case kIROp_BackwardDerivativePrimalReturnDecoration:
case kIROp_PrimalValueStructKeyDecoration:
+ case kIROp_PrimalElementTypeDecoration:
decor->removeAndDeallocate();
break;
default:
@@ -578,6 +589,7 @@ bool canTypeBeStored(IRInst* type)
case kIROp_TupleType:
case kIROp_ArrayType:
case kIROp_DifferentialPairType:
+ case kIROp_DifferentialPairUserCodeType:
case kIROp_InterfaceType:
case kIROp_AnyValueType:
case kIROp_ClassType:
@@ -832,6 +844,13 @@ struct AutoDiffPass : public InstPassBase
if (!changed)
break;
+
+ // We have done transcribing the functions, now it is time to demote all DifferentialPair types
+ // and their operations down to DifferentialPairUserCodeType and *UserCode operations so they
+ // can be treated just like normal types with no special semantics in future processing, and won't
+ // be confused with the semantics of a DifferentialPair type during future autodiff code gen.
+ rewriteDifferentialPairToUserCode(module);
+
hasChanges |= changed;
}