diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-19 03:10:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-19 00:10:28 -0700 |
| commit | ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch) | |
| tree | 435e9c462a78fb848ab3b36c23287543d1a859de /source/slang/slang-ir-autodiff-transcriber-base.cpp | |
| parent | 1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff) | |
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface
* Initial support for `IDifferentiablePtrType`
* Fix unused vars
* More tests + fix switch case fallthrough.
* Update slang-ir-autodiff.cpp
* Update diff-ptr-type-loop.slang
* Add optimization to allow more complex pair types
* Update slang-ir-autodiff-primal-hoist.cpp
* Update diff-ptr-type-loop.slang
* Update slang-ir-autodiff-primal-hoist.cpp
* More fixes to address reviews
* Update slang-check-expr.cpp
* Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType`
* Move pair logic to ir-builder, unify the type dictionaries.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-ir-autodiff-transcriber-base.cpp')
| -rw-r--r-- | source/slang/slang-ir-autodiff-transcriber-base.cpp | 173 |
1 files changed, 143 insertions, 30 deletions
diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 1fa76c730..2141837b5 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -174,45 +174,54 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); -IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) +IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind) { - return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType); + if (kind == DiffConformanceKind::Any) + { + if (auto valueWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Value)) + return valueWitness; + if (auto ptrWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Ptr)) + return ptrWitness; + } + else + { + return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, kind); + } + return nullptr; } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) { - return builder->getDifferentialPairType( - (IRType*)primalType, - witness); + auto conformanceType = differentiableTypeConformanceContext.getConformanceTypeFromWitness(witness); + if (autoDiffSharedContext->isInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiableInterfaceType) + { + return builder->getDifferentialPairType((IRType*)primalType, witness); + } + else if (autoDiffSharedContext->isPtrInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiablePtrInterfaceType) + { + return builder->getDifferentialPtrPairType((IRType*)primalType, witness); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType) { - auto primalType = lookupPrimalInst(builder, originalType, nullptr); + auto primalType = lookupPrimalInst(builder, originalType, originalType); SLANG_RELEASE_ASSERT(primalType); IRInst* witness = nullptr; - if (auto lookup = as<IRLookupWitnessMethod>(primalType)) - { - if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey) - { - witness = builder->emitLookupInterfaceMethodInst( - lookup->getWitnessTable()->getDataType(), - lookup->getWitnessTable(), - autoDiffSharedContext->differentialAssocTypeWitnessStructKey); - } - } - - // Obtain the witness that primalType conforms to IDifferentiable. + + // Obtain the witness that primalType conforms to IDifferentiable/IDifferentiablePtrType if (!witness) - witness = tryGetDifferentiableWitness(builder, originalType); + witness = tryGetDifferentiableWitness(builder, primalType, DiffConformanceKind::Any); SLANG_RELEASE_ASSERT(witness); - auto pairType = builder->getDifferentialPairType( - (IRType*)primalType, - witness); - - return pairType; + return getOrCreateDiffPairType(builder, primalType, witness); } IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) @@ -223,8 +232,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o // Special-case for differentiable existential types. if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType)) { - if (differentiableTypeConformanceContext.lookUpConformanceForType(origType)) + if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Value)) return autoDiffSharedContext->differentiableInterfaceType; + else if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Ptr)) + return autoDiffSharedContext->differentiablePtrInterfaceType; else return nullptr; } @@ -278,8 +289,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } case kIROp_DifferentialPairType: + case kIROp_DifferentialPtrPairType: { - auto primalPairType = as<IRDifferentialPairType>(primalType); + auto primalPairType = as<IRDifferentialPairTypeBase>(primalType); return getOrCreateDiffPairType( builder, differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), @@ -445,8 +457,24 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType())); if (!interfaceType) return nullptr; - List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath( + + List<IRInterfaceRequirementEntry*> lookupKeyPath; + IRStructKey* diffStructKey = nullptr; + + List<IRInterfaceRequirementEntry*> lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath( autoDiffSharedContext->differentiableInterfaceType, interfaceType); + if (lookupPathValueType.getCount() > 0) + { + lookupKeyPath = lookupPathValueType; + diffStructKey = autoDiffSharedContext->differentialAssocTypeStructKey; + } + else + { + // Try IDifferentiablePtrType + lookupKeyPath = differentiableTypeConformanceContext.findInterfaceLookupPath( + autoDiffSharedContext->differentiablePtrInterfaceType, interfaceType); + diffStructKey = autoDiffSharedContext->differentialAssocRefTypeStructKey; + } if (lookupKeyPath.getCount()) { @@ -456,7 +484,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* { outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey()); } - auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey); + auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, diffStructKey); return (IRType*)diffType; } return nullptr; @@ -561,10 +589,31 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui return InstPair(primal, diffWitness); } + else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiablePtrInterfaceType) + { + auto primalDiffType = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + primal, + autoDiffSharedContext->differentialAssocRefTypeStructKey); + auto diffWitness = builder->emitLookupInterfaceMethodInst( + (IRType*)primalDiffType, + primal, + autoDiffSharedContext->differentialAssocRefTypeWitnessStructKey); + + // Mark both as primal since we're working with types + // (which don't need transposing) + // + builder->markInstAsPrimal(primalDiffType); + builder->markInstAsPrimal(diffWitness); + + return InstPair(primal, diffWitness); + } } + auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); + if (!decor) { return InstPair(primal, nullptr); @@ -589,6 +638,10 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType( { originalType = (IRType*)unwrapAttributedType(originalType); auto primalType = (IRType*)lookupPrimalInst(builder, originalType); + + // Can't generate zero for differentiable ptr types. Should never hit this case. + SLANG_ASSERT(!differentiableTypeConformanceContext.isDifferentiablePtrType(originalType)); + if (auto diffType = differentiateType(builder, originalType)) { IRInst* diffWitnessTable = nullptr; @@ -985,7 +1038,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst && !as<IRConstant>(pair.differential)) { auto primalType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + markDiffTypeInst(builder, pair.differential, primalType); } } else @@ -997,7 +1050,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst if (as<IRType>(pair.differential)) break; auto mixedType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsMixedDifferential(pair.primal, mixedType); + markDiffPairTypeInst(builder, pair.primal, mixedType); } } @@ -1076,4 +1129,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori return result; } + +void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType) +{ + // Ignore module-level insts. + if (as<IRModuleInst>(diffInst->getParent())) + return; + + // Also ignore generic-container-level insts. + if (as<IRBlock>(diffInst->getParent()) && + as<IRGeneric>(diffInst->getParent()->getParent())) + return; + + // TODO: This logic is a bit of a hack. We need to determine if the type is + // relevant to ptr-type computation or not, or more complex applications + // that use dynamic dispatch + ptr types will fail. + // + if (as<IRType>(diffInst)) + { + builder->markInstAsDifferential(diffInst, nullptr); + return; + } + + SLANG_ASSERT(diffInst); + SLANG_ASSERT(primalType); + + if (differentiableTypeConformanceContext.isDifferentiableValueType(primalType)) + { + builder->markInstAsDifferential(diffInst, primalType); + } + else if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType)) + { + builder->markInstAsPrimal(diffInst); + } + else + { + // Stop-gap solution to go with differential inst for now. + builder->markInstAsDifferential(diffInst, primalType); + } +} + +void AutoDiffTranscriberBase::markDiffPairTypeInst(IRBuilder* builder, IRInst* diffPairInst, IRType* pairType) +{ + SLANG_ASSERT(diffPairInst); + SLANG_ASSERT(pairType); + SLANG_ASSERT(as<IRDifferentialPairTypeBase>(pairType)); + + if (as<IRDifferentialPairType>(pairType)) + { + builder->markInstAsMixedDifferential(diffPairInst, pairType); + } + else if (as<IRDifferentialPtrPairType>(pairType)) + { + builder->markInstAsPrimal(diffPairInst); + } + else + { + SLANG_UNEXPECTED("unexpected differentiable type"); + } +} + } |
