summaryrefslogtreecommitdiff
path: root/source/slang/slang-ir-autodiff-transcriber-base.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2024-09-19 03:10:28 -0400
committerGitHub <noreply@github.com>2024-09-19 00:10:28 -0700
commitccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch)
tree435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-ir-autodiff-transcriber-base.cpp
parent1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff)
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface * Initial support for `IDifferentiablePtrType` * Fix unused vars * More tests + fix switch case fallthrough. * Update slang-ir-autodiff.cpp * Update diff-ptr-type-loop.slang * Add optimization to allow more complex pair types * Update slang-ir-autodiff-primal-hoist.cpp * Update diff-ptr-type-loop.slang * Update slang-ir-autodiff-primal-hoist.cpp * More fixes to address reviews * Update slang-check-expr.cpp * Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType` * Move pair logic to ir-builder, unify the type dictionaries. --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
-rw-r--r--source/slang/slang-ir-autodiff-transcriber-base.cpp173
1 files changed, 143 insertions, 30 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp
index 1fa76c730..2141837b5 100644
--- a/source/slang/slang-ir-autodiff-transcriber-base.cpp
+++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp
@@ -174,45 +174,54 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI
IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey);
-IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType)
+IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind)
{
- return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType);
+ if (kind == DiffConformanceKind::Any)
+ {
+ if (auto valueWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Value))
+ return valueWitness;
+ if (auto ptrWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Ptr))
+ return ptrWitness;
+ }
+ else
+ {
+ return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, kind);
+ }
+ return nullptr;
}
IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness)
{
- return builder->getDifferentialPairType(
- (IRType*)primalType,
- witness);
+ auto conformanceType = differentiableTypeConformanceContext.getConformanceTypeFromWitness(witness);
+ if (autoDiffSharedContext->isInterfaceAvailable &&
+ conformanceType == autoDiffSharedContext->differentiableInterfaceType)
+ {
+ return builder->getDifferentialPairType((IRType*)primalType, witness);
+ }
+ else if (autoDiffSharedContext->isPtrInterfaceAvailable &&
+ conformanceType == autoDiffSharedContext->differentiablePtrInterfaceType)
+ {
+ return builder->getDifferentialPtrPairType((IRType*)primalType, witness);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Unexpected witness type");
+ }
}
IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType)
{
- auto primalType = lookupPrimalInst(builder, originalType, nullptr);
+ auto primalType = lookupPrimalInst(builder, originalType, originalType);
SLANG_RELEASE_ASSERT(primalType);
IRInst* witness = nullptr;
- if (auto lookup = as<IRLookupWitnessMethod>(primalType))
- {
- if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey)
- {
- witness = builder->emitLookupInterfaceMethodInst(
- lookup->getWitnessTable()->getDataType(),
- lookup->getWitnessTable(),
- autoDiffSharedContext->differentialAssocTypeWitnessStructKey);
- }
- }
-
- // Obtain the witness that primalType conforms to IDifferentiable.
+
+ // Obtain the witness that primalType conforms to IDifferentiable/IDifferentiablePtrType
if (!witness)
- witness = tryGetDifferentiableWitness(builder, originalType);
+ witness = tryGetDifferentiableWitness(builder, primalType, DiffConformanceKind::Any);
SLANG_RELEASE_ASSERT(witness);
- auto pairType = builder->getDifferentialPairType(
- (IRType*)primalType,
- witness);
-
- return pairType;
+ return getOrCreateDiffPairType(builder, primalType, witness);
}
IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType)
@@ -223,8 +232,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o
// Special-case for differentiable existential types.
if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType))
{
- if (differentiableTypeConformanceContext.lookUpConformanceForType(origType))
+ if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Value))
return autoDiffSharedContext->differentiableInterfaceType;
+ else if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Ptr))
+ return autoDiffSharedContext->differentiablePtrInterfaceType;
else
return nullptr;
}
@@ -278,8 +289,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy
}
case kIROp_DifferentialPairType:
+ case kIROp_DifferentialPtrPairType:
{
- auto primalPairType = as<IRDifferentialPairType>(primalType);
+ auto primalPairType = as<IRDifferentialPairTypeBase>(primalType);
return getOrCreateDiffPairType(
builder,
differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType),
@@ -445,8 +457,24 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType()));
if (!interfaceType)
return nullptr;
- List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath(
+
+ List<IRInterfaceRequirementEntry*> lookupKeyPath;
+ IRStructKey* diffStructKey = nullptr;
+
+ List<IRInterfaceRequirementEntry*> lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath(
autoDiffSharedContext->differentiableInterfaceType, interfaceType);
+ if (lookupPathValueType.getCount() > 0)
+ {
+ lookupKeyPath = lookupPathValueType;
+ diffStructKey = autoDiffSharedContext->differentialAssocTypeStructKey;
+ }
+ else
+ {
+ // Try IDifferentiablePtrType
+ lookupKeyPath = differentiableTypeConformanceContext.findInterfaceLookupPath(
+ autoDiffSharedContext->differentiablePtrInterfaceType, interfaceType);
+ diffStructKey = autoDiffSharedContext->differentialAssocRefTypeStructKey;
+ }
if (lookupKeyPath.getCount())
{
@@ -456,7 +484,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder*
{
outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey());
}
- auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey);
+ auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, diffStructKey);
return (IRType*)diffType;
}
return nullptr;
@@ -561,10 +589,31 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui
return InstPair(primal, diffWitness);
}
+ else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiablePtrInterfaceType)
+ {
+ auto primalDiffType = builder->emitLookupInterfaceMethodInst(
+ builder->getTypeKind(),
+ primal,
+ autoDiffSharedContext->differentialAssocRefTypeStructKey);
+ auto diffWitness = builder->emitLookupInterfaceMethodInst(
+ (IRType*)primalDiffType,
+ primal,
+ autoDiffSharedContext->differentialAssocRefTypeWitnessStructKey);
+
+ // Mark both as primal since we're working with types
+ // (which don't need transposing)
+ //
+ builder->markInstAsPrimal(primalDiffType);
+ builder->markInstAsPrimal(diffWitness);
+
+ return InstPair(primal, diffWitness);
+ }
}
+
auto decor =
lookupInst->getRequirementKey()->findDecorationImpl(
getInterfaceRequirementDerivativeDecorationOp());
+
if (!decor)
{
return InstPair(primal, nullptr);
@@ -589,6 +638,10 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(
{
originalType = (IRType*)unwrapAttributedType(originalType);
auto primalType = (IRType*)lookupPrimalInst(builder, originalType);
+
+ // Can't generate zero for differentiable ptr types. Should never hit this case.
+ SLANG_ASSERT(!differentiableTypeConformanceContext.isDifferentiablePtrType(originalType));
+
if (auto diffType = differentiateType(builder, originalType))
{
IRInst* diffWitnessTable = nullptr;
@@ -985,7 +1038,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
&& !as<IRConstant>(pair.differential))
{
auto primalType = (IRType*)(pair.primal->getDataType());
- builder->markInstAsDifferential(pair.differential, primalType);
+ markDiffTypeInst(builder, pair.differential, primalType);
}
}
else
@@ -997,7 +1050,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst
if (as<IRType>(pair.differential))
break;
auto mixedType = (IRType*)(pair.primal->getDataType());
- builder->markInstAsMixedDifferential(pair.primal, mixedType);
+ markDiffPairTypeInst(builder, pair.primal, mixedType);
}
}
@@ -1076,4 +1129,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori
return result;
}
+
+void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType)
+{
+ // Ignore module-level insts.
+ if (as<IRModuleInst>(diffInst->getParent()))
+ return;
+
+ // Also ignore generic-container-level insts.
+ if (as<IRBlock>(diffInst->getParent()) &&
+ as<IRGeneric>(diffInst->getParent()->getParent()))
+ return;
+
+ // TODO: This logic is a bit of a hack. We need to determine if the type is
+ // relevant to ptr-type computation or not, or more complex applications
+ // that use dynamic dispatch + ptr types will fail.
+ //
+ if (as<IRType>(diffInst))
+ {
+ builder->markInstAsDifferential(diffInst, nullptr);
+ return;
+ }
+
+ SLANG_ASSERT(diffInst);
+ SLANG_ASSERT(primalType);
+
+ if (differentiableTypeConformanceContext.isDifferentiableValueType(primalType))
+ {
+ builder->markInstAsDifferential(diffInst, primalType);
+ }
+ else if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType))
+ {
+ builder->markInstAsPrimal(diffInst);
+ }
+ else
+ {
+ // Stop-gap solution to go with differential inst for now.
+ builder->markInstAsDifferential(diffInst, primalType);
+ }
+}
+
+void AutoDiffTranscriberBase::markDiffPairTypeInst(IRBuilder* builder, IRInst* diffPairInst, IRType* pairType)
+{
+ SLANG_ASSERT(diffPairInst);
+ SLANG_ASSERT(pairType);
+ SLANG_ASSERT(as<IRDifferentialPairTypeBase>(pairType));
+
+ if (as<IRDifferentialPairType>(pairType))
+ {
+ builder->markInstAsMixedDifferential(diffPairInst, pairType);
+ }
+ else if (as<IRDifferentialPtrPairType>(pairType))
+ {
+ builder->markInstAsPrimal(diffPairInst);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("unexpected differentiable type");
+ }
+}
+
}