diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2024-09-19 03:10:28 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-09-19 00:10:28 -0700 |
| commit | ccc310fa4e8096cda8a6c127aacc1a1fa9d8503a (patch) | |
| tree | 435e9c462a78fb848ab3b36c23287543d1a859de /source | |
| parent | 1781c2969eb65fb7ade01d3f0d7d9b8973bcd4d3 (diff) | |
Support `IDifferentiablePtrType` (#5031)
* initial diff-ref-type interface
* Initial support for `IDifferentiablePtrType`
* Fix unused vars
* More tests + fix switch case fallthrough.
* Update slang-ir-autodiff.cpp
* Update diff-ptr-type-loop.slang
* Add optimization to allow more complex pair types
* Update slang-ir-autodiff-primal-hoist.cpp
* Update diff-ptr-type-loop.slang
* Update slang-ir-autodiff-primal-hoist.cpp
* More fixes to address reviews
* Update slang-check-expr.cpp
* Optimizations + rename `differentiableRefInterfaceType` -> `differentiablePtrInterfaceType`
* Move pair logic to ir-builder, unify the type dictionaries.
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
24 files changed, 1215 insertions, 379 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index afcff8e65..476279ab8 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -285,6 +285,13 @@ interface IDifferentiable static Differential dmul(T, Differential); }; +__magic_type(DifferentiablePtrType) +interface IDifferentiablePtrType +{ + __builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) ) + associatedtype Differential : IDifferentiablePtrType; +}; + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. @@ -357,6 +364,36 @@ struct DifferentialPair : IDifferentiable } }; +__generic<T : IDifferentiablePtrType> +__magic_type(DifferentialPtrPairType) +__intrinsic_type($(kIROp_DifferentialPtrPairType)) +struct DifferentialPtrPair : IDifferentiablePtrType +{ + typedef DifferentialPtrPair<T.Differential> Differential; + typedef T.Differential DifferentialElementType; + + __intrinsic_op($(kIROp_MakeDifferentialPtrPair)) + __init(T _primal, T.Differential _differential); + + property p : T + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) + get; + } + + property v : T + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetPrimal)) + get; + } + + property d : T.Differential + { + __intrinsic_op($(kIROp_DifferentialPtrPairGetDifferential)) + get; + } +}; + /// A type that uses a floating-point representation [sealed] diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 9879a4187..b66af34fa 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -408,18 +408,32 @@ MatrixExpressionType* ASTBuilder::getMatrixType(Type* elementType, IntVal* rowCo DifferentialPairType* ASTBuilder::getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness) + Witness* diffTypeWitness) { - Val* args[] = { valueType, primalIsDifferentialWitness }; + Val* args[] = { valueType, diffTypeWitness }; return as<DifferentialPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPairType")); } +DifferentialPtrPairType* ASTBuilder::getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness) +{ + Val* args[] = { valueType, diffRefTypeWitness }; + return as<DifferentialPtrPairType>(getSpecializedBuiltinType(makeArrayView(args), "DifferentialPtrPairType")); +} + DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableInterfaceDecl() { DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiableType", nullptr)); return declRef; } +DeclRef<InterfaceDecl> ASTBuilder::getDifferentiableRefInterfaceDecl() +{ + DeclRef<InterfaceDecl> declRef = DeclRef<InterfaceDecl>(getBuiltinDeclRef("DifferentiablePtrType", nullptr)); + return declRef; +} + bool ASTBuilder::isDifferentiableInterfaceAvailable() { return (m_sharedASTBuilder->tryFindMagicDecl("DifferentiableType") != nullptr); @@ -459,6 +473,11 @@ Type* ASTBuilder::getDifferentiableInterfaceType() return DeclRefType::create(this, getDifferentiableInterfaceDecl()); } +Type* ASTBuilder::getDifferentiableRefInterfaceType() +{ + return DeclRefType::create(this, getDifferentiableRefInterfaceDecl()); +} + DeclRef<Decl> ASTBuilder::getBuiltinDeclRef(const char* builtinMagicTypeName, Val* genericArg) { auto decl = m_sharedASTBuilder->findMagicDecl(builtinMagicTypeName); diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index b9b1f7ab8..08951513d 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -489,10 +489,17 @@ public: DifferentialPairType* getDifferentialPairType( Type* valueType, - Witness* primalIsDifferentialWitness); + Witness* diffTypeWitness); + + DifferentialPtrPairType* getDifferentialPtrPairType( + Type* valueType, + Witness* diffRefTypeWitness); DeclRef<InterfaceDecl> getDifferentiableInterfaceDecl(); + DeclRef<InterfaceDecl> getDifferentiableRefInterfaceDecl(); + Type* getDifferentiableInterfaceType(); + Type* getDifferentiableRefInterfaceType(); bool isDifferentiableInterfaceAvailable(); diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 83a4cf353..56101bb91 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -9,7 +9,7 @@ #include "slang-profile.h" #include "slang-type-system-shared.h" -#include "slang.h" +#include "../../include/slang.h" #include "../core/slang-semantic-version.h" @@ -1606,6 +1606,7 @@ namespace Slang DefaultInitializableConstructor, ///< The `IDefaultInitializable.__init()` method DifferentialType, ///< The `IDifferentiable.Differential` associated type requirement + DifferentialPtrType, ///< The `IDifferentiable.DifferentialPtr` associated type requirement DZeroFunc, ///< The `IDifferentiable.dzero` function requirement DAddFunc, ///< The `IDifferentiable.dadd` function requirement DMulFunc, ///< The `IDifferentiable.dmul` function requirement diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index 401d73e29..46ea3ea55 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -462,11 +462,22 @@ class DifferentialPairType : public ArithmeticExpressionType Type* getPrimalType(); }; +class DifferentialPtrPairType : public ArithmeticExpressionType +{ + SLANG_AST_CLASS(DifferentialPtrPairType) + Type* getPrimalRefType(); +}; + class DifferentiableType : public BuiltinType { SLANG_AST_CLASS(DifferentiableType) }; +class DifferentiablePtrType : public BuiltinType +{ + SLANG_AST_CLASS(DifferentiablePtrType) +}; + class DefaultInitializableType : public BuiltinType { SLANG_AST_CLASS(DefaultInitializableType); diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index ffa037996..9d9047e41 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -274,9 +274,14 @@ namespace Slang return isInterfaceType(type); } - bool SemanticsVisitor::isTypeDifferentiable(Type* type) + SubtypeWitness* SemanticsVisitor::isTypeDifferentiable(Type* type) { - return isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None); + if (auto valueWitness = isSubtype(type, m_astBuilder->getDiffInterfaceType(), IsSubTypeOptions::None)) + return valueWitness; + else if (auto ptrWitness = isSubtype(type, m_astBuilder->getDifferentiableRefInterfaceType(), IsSubTypeOptions::None)) + return ptrWitness; + + return nullptr; } bool SemanticsVisitor::doesTypeHaveTag(Type* type, TypeTag tag) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index deb8c55eb..8e78ff084 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -10204,7 +10204,8 @@ namespace Slang bool isDiffParam = (!param->findModifier<NoDiffModifier>()); if (isDiffParam) { - if (auto pairType = as<DifferentialPairType>(visitor->getDifferentialPairType(param->getType()))) + auto diffPair = visitor->getDifferentialPairType(param->getType()); + if (auto pairType = as<DifferentialPairType>(diffPair)) { arg->type.type = pairType; arg->type.isLeftValue = true; @@ -10225,6 +10226,11 @@ namespace Slang direction = ParameterDirection::kParameterDirection_InOut; } } + else if (auto refPairType = as<DifferentialPtrPairType>(diffPair)) + { + // no need to change direction of ref-pairs. + arg->type.type = refPairType; + } else { isDiffParam = false; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 8f24ec5b0..5233008fd 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -1131,7 +1131,8 @@ namespace Slang { if (auto builtinRequirement = declRefType->getDeclRef().getDecl()->findModifier<BuiltinRequirementModifier>()) { - if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType) + if (builtinRequirement->kind == BuiltinRequirementKind::DifferentialType + || builtinRequirement->kind == BuiltinRequirementKind::DifferentialPtrType) { // We are trying to get differential type from a differential type. // The result is itself. @@ -1139,7 +1140,10 @@ namespace Slang } } type = resolveType(type); - if (const auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType()))) + auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterfaceType())); + if (!witness) + witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableRefInterfaceType())); + if (witness) { auto diffTypeLookupResult = lookUpMember( getASTBuilder(), @@ -1367,6 +1371,13 @@ namespace Slang { addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); } + + if (auto subtypeWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableRefInterfaceType()))) + { + addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness); + } + if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>()) { foreachDirectOrExtensionMemberOfType<InheritanceDecl>(this, aggTypeDeclRef, [&](DeclRef<InheritanceDecl> member) @@ -2899,18 +2910,25 @@ namespace Slang return m_astBuilder->getExpandType(diffPairEachType, makeArrayViewSingle(primalType)); } } + // Get a reference to the builtin 'IDifferentiable' interface auto differentiableInterface = getASTBuilder()->getDifferentiableInterfaceType(); + auto differentiableRefInterface = getASTBuilder()->getDifferentiableRefInterfaceType(); - auto conformanceWitness = as<Witness>(isSubtype(primalType, differentiableInterface, IsSubTypeOptions::None)); // Check if the provided type inherits from IDifferentiable. // If not, return the original type. - if (conformanceWitness) + if (auto conformanceWitness = isTypeDifferentiable(primalType)) { - return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + if (conformanceWitness->getSup() == differentiableInterface) + { + return m_astBuilder->getDifferentialPairType(primalType, conformanceWitness); + } + else if (conformanceWitness->getSup() == differentiableRefInterface) + { + return m_astBuilder->getDifferentialPtrPairType(primalType, conformanceWitness); + } } - else - return primalType; + return primalType; } Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index adb7e81f3..29a57ae35 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2208,7 +2208,7 @@ namespace Slang bool isValidGenericConstraintType(Type* type); - bool isTypeDifferentiable(Type* type); + SubtypeWitness* isTypeDifferentiable(Type* type); bool doesTypeHaveTag(Type* type, TypeTag tag); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index fe7c77ba0..609bcd8a3 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -336,8 +336,8 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig { auto origPtr = origLoad->getPtr(); auto primalPtr = lookupPrimalInst(builder, origPtr, nullptr); - auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType()); - if (primalPtrType) + + if (auto primalPtrType = as<IRPtrTypeBase>(primalPtr->getFullType())) { if (auto diffPairType = as<IRDifferentialPairType>(primalPtrType->getValueType())) { @@ -355,6 +355,18 @@ InstPair ForwardDiffTranscriber::transcribeLoad(IRBuilder* builder, IRLoad* orig (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPairType), load); return InstPair(primalElement, diffElement); } + else if (auto diffPtrPairType = as<IRDifferentialPtrPairType>(primalPtrType->getValueType())) + { + auto load = builder->emitLoad(primalPtr); + builder->markInstAsPrimal(load); + + auto primalElement = builder->emitDifferentialPtrPairGetPrimal(load); + auto diffElement = builder->emitDifferentialPtrPairGetDifferential( + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, diffPtrPairType), load); + builder->markInstAsPrimal(primalElement); + builder->markInstAsPrimal(diffElement); + return InstPair(primalElement, diffElement); + } } auto primalLoad = maybeCloneForPrimalInst(builder, origLoad); @@ -389,6 +401,16 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or return InstPair(store, nullptr); } + else if (auto diffRefPairType = as<IRDifferentialPtrPairType>(primalLocationPtrType->getValueType())) + { + auto valToStore = builder->emitMakeDifferentialPtrPair(diffRefPairType, primalStoreVal, diffStoreVal); + builder->markInstAsPrimal(valToStore); + + auto store = builder->emitStore(primalStoreLocation, valToStore); + builder->markInstAsPrimal(store); + + return InstPair(store, nullptr); + } } auto primalStore = maybeCloneForPrimalInst(builder, origStore); @@ -404,7 +426,7 @@ InstPair ForwardDiffTranscriber::transcribeStore(IRBuilder* builder, IRStore* or // Default case, storing the entire type (and not a member) diffStore = as<IRStore>( builder->emitStore(diffStoreLocation, diffStoreVal)); - + markDiffTypeInst(builder, diffStore, primalStoreVal->getDataType()); return InstPair(primalStore, diffStore); } @@ -696,14 +718,16 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig if (auto pairType = tryGetDiffPairType(&argBuilder, primalType)) { auto pairPtrType = as<IRPtrTypeBase>(pairType); - auto pairValType = as<IRDifferentialPairType>( + + auto pairValType = as<IRDifferentialPairTypeBase>( pairPtrType ? pairPtrType->getValueType() : pairType); + auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType); if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType)) { // Create temp var to pass in/out arguments. auto srcVar = argBuilder.emitVar(pairValType); - argBuilder.markInstAsMixedDifferential(srcVar, pairValType->getValueType()); + markDiffPairTypeInst(&argBuilder, srcVar, pairValType); auto diffArg = findOrTranscribeDiffInst(&argBuilder, origArg); if (ptrParamType->getOp() == kIROp_InOutType) @@ -716,28 +740,28 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig else { diffArgVal = argBuilder.emitLoad(diffArg); - argBuilder.markInstAsDifferential(diffArgVal, pairValType->getValueType()); + markDiffTypeInst(&argBuilder, diffArgVal, pairValType->getValueType()); } auto initVal = argBuilder.emitMakeDifferentialPair(pairValType, primalVal, diffArgVal); - argBuilder.markInstAsMixedDifferential(initVal, primalType); + markDiffPairTypeInst(&argBuilder, initVal, pairValType); auto store = argBuilder.emitStore(srcVar, initVal); - argBuilder.markInstAsMixedDifferential(store, primalType); + markDiffPairTypeInst(&argBuilder, store, pairValType); } if (as<IROutTypeBase>(ptrParamType)) { // Read back new value. auto newVal = afterBuilder.emitLoad(srcVar); - afterBuilder.markInstAsMixedDifferential(newVal, pairValType->getValueType()); + markDiffPairTypeInst(&afterBuilder, newVal, pairValType); auto newPrimalVal = afterBuilder.emitDifferentialPairGetPrimal(pairValType->getValueType(), newVal); afterBuilder.emitStore(primalArg, newPrimalVal); if (diffArg) { auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal); - afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType()); + markDiffTypeInst(&afterBuilder, newDiffVal, pairValType->getValueType()); auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal); - afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType()); + markDiffTypeInst(&afterBuilder, storeInst, pairValType->getValueType()); } } args.add(srcVar); @@ -753,7 +777,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig SLANG_RELEASE_ASSERT(diffArg); auto diffPair = argBuilder.emitMakeDifferentialPair(pairType, primalArg, diffArg); - argBuilder.markInstAsMixedDifferential(diffPair, pairType); + markDiffPairTypeInst(&argBuilder, diffPair, pairType); args.add(diffPair); continue; @@ -779,12 +803,13 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig diffCallee, args); placeholderCall->removeAndDeallocate(); + argBuilder.markInstAsMixedDifferential(callInst, diffReturnType); argBuilder.addAutoDiffOriginalValueDecoration(callInst, primalCallee); *builder = afterBuilder; - if (diffReturnType->getOp() == kIROp_DifferentialPairType) + if (as<IRDifferentialPairType>(diffReturnType) || as<IRDifferentialPtrPairType>(diffReturnType)) { IRInst* primalResultValue = afterBuilder.emitDifferentialPairGetPrimal(callInst); auto diffType = differentiateType(&afterBuilder, origCall->getFullType()); @@ -1751,12 +1776,13 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr IRInst* valToStore = nullptr; if (writeBack.value.differential) { + auto pairValType = cast<IRPtrTypeBase>(param->getFullType())->getValueType(); auto diffVal = builder.emitLoad(writeBack.value.differential); - builder.markInstAsDifferential(diffVal, primalVal->getFullType()); + markDiffTypeInst(&builder, diffVal, primalVal->getFullType()); - valToStore = builder.emitMakeDifferentialPair(cast<IRPtrTypeBase>(param->getFullType())->getValueType(), - primalVal, diffVal); - builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType()); + valToStore = builder.emitMakeDifferentialPair(pairValType, primalVal, diffVal); + + markDiffPairTypeInst(&builder, valToStore, pairValType); } else { @@ -1767,7 +1793,7 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr if (writeBack.value.differential) { - builder.markInstAsMixedDifferential(storeInst, valToStore->getFullType()); + markDiffPairTypeInst(&builder, storeInst, valToStore->getFullType()); } } } @@ -2043,24 +2069,25 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam SLANG_ASSERT(diffPairParam); - if (auto pairType = as<IRDifferentialPairType>(diffPairType)) + if (as<IRDifferentialPairType>(diffPairType) || as<IRDifferentialPtrPairType>(diffPairType)) { return InstPair( builder->emitDifferentialPairGetPrimal(diffPairParam), builder->emitDifferentialPairGetDifferential( - (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, pairType), + (IRType*)differentiableTypeConformanceContext.getDiffTypeFromPairType( + builder, + as<IRDifferentialPairTypeBase>(diffPairType)), diffPairParam)); } else if (auto pairPtrType = as<IRPtrTypeBase>(diffPairType)) { - auto ptrInnerPairType = as<IRDifferentialPairType>(pairPtrType->getValueType()); + auto ptrInnerPairType = as<IRDifferentialPairTypeBase>(pairPtrType->getValueType()); // Make a local copy of the parameter for primal and diff parts. auto primal = builder->emitVar(ptrInnerPairType->getValueType()); auto diffType = differentiateType(builder, cast<IRPtrTypeBase>(origParam->getDataType())->getValueType()); auto diff = builder->emitVar(diffType); - builder->markInstAsDifferential( - diff, builder->getPtrType(ptrInnerPairType->getValueType())); + markDiffTypeInst(builder, diff, builder->getPtrType(ptrInnerPairType->getValueType())); IRInst* primalInitVal = nullptr; IRInst* diffInitVal = nullptr; @@ -2072,17 +2099,18 @@ InstPair ForwardDiffTranscriber::transcribeFuncParam(IRBuilder* builder, IRParam else { auto initVal = builder->emitLoad(diffPairParam); - builder->markInstAsMixedDifferential(initVal, ptrInnerPairType); + markDiffPairTypeInst(builder, initVal, ptrInnerPairType); primalInitVal = builder->emitDifferentialPairGetPrimal(initVal); diffInitVal = builder->emitDifferentialPairGetDifferential(diffType, initVal); } - builder->markInstAsDifferential(diffInitVal, ptrInnerPairType->getValueType()); + markDiffTypeInst(builder, diffInitVal, ptrInnerPairType->getValueType()); + builder->emitStore(primal, primalInitVal); auto diffStore = builder->emitStore(diff, diffInitVal); - builder->markInstAsDifferential(diffStore, ptrInnerPairType->getValueType()); + markDiffTypeInst(builder, diffStore, ptrInnerPairType->getValueType()); mapInOutParamToWriteBackValue[diffPairParam] = InstPair(primal, diff); return InstPair(primal, diff); diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index 7fc8ebbe6..3a6d52bea 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -107,10 +107,13 @@ struct DiffPairLoweringPass : InstPassBase case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferentialUserCode: case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetDifferential: + case kIROp_DifferentialPtrPairGetPrimal: lowerPairAccess(builder, inst); break; case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: lowerMakePair(builder, inst); break; diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index f51178f0f..2881abe3e 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -892,6 +892,16 @@ void applyToInst( } } SLANG_ASSERT(replacement); + + // If the replacement and inst are not the exact same type, use an int-cast + // (e.g. uint vs. int) + // + if (replacement->getDataType() != inst->getDataType()) + { + setInsertAfterOrdinaryInst(builder, replacement); + replacement = builder->emitCast(inst->getDataType(), replacement); + } + cloneCtx->cloneEnv.mapOldValToNew[inst] = replacement; cloneCtx->registerClonedInst(builder, inst, replacement); return; @@ -1998,6 +2008,7 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_MakeArrayFromElement: case kIROp_MakeDifferentialPair: case kIROp_MakeDifferentialPairUserCode: + case kIROp_MakeDifferentialPtrPair: case kIROp_MakeOptionalNone: case kIROp_MakeOptionalValue: case kIROp_MakeExistential: @@ -2005,6 +2016,8 @@ static bool shouldStoreInst(IRInst* inst) case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferentialUserCode: case kIROp_DifferentialPairGetPrimalUserCode: + case kIROp_DifferentialPtrPairGetDifferential: + case kIROp_DifferentialPtrPairGetPrimal: case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialType: case kIROp_ExtractExistentialWitnessTable: diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 2fb73c4ac..169dd31ee 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -152,6 +152,16 @@ namespace Slang builder->emitBlock(); params = _defineFuncParams(builder, as<IRFunc>(existingPrimalFunc)); params.removeLast(); + + // Unwrap any ref pairs. We need this special case for trivial funcs. + for (Int i = 0; i < params.getCount(); i++) + { + if (as<IRDifferentialPtrPairType>(params[i]->getDataType())) + { + params[i] = builder->emitDifferentialPtrPairGetPrimal(params[i]); + } + } + IRInst* originalFuncRefFromPrimalFunc = originalFunc; if (originalGeneric) originalFuncRefFromPrimalFunc = maybeSpecializeWithGeneric(*builder, originalGeneric, existingPriamlFuncGeneric); @@ -266,7 +276,20 @@ namespace Slang if (auto primalNoDiffType = _getPrimalTypeFromNoDiffType(this, builder, paramType)) return primalNoDiffType; - return (IRType*)findOrTranscribePrimalInst(builder, paramType); + auto primalType = (IRType*)findOrTranscribePrimalInst(builder, paramType); + + // Differentiable pointer types are treated as primal pairs, since they aren't involved in the transposition + // process. + // + if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType)) + { + auto diffPairType = tryGetDiffPairType(builder, primalType); + SLANG_ASSERT(diffPairType); + + return diffPairType; + } + + return primalType; } IRType* BackwardDiffTranscriberBase::transcribeParamTypeForPropagateFunc(IRBuilder* builder, IRType* paramType) @@ -292,7 +315,7 @@ namespace Slang auto diffPairType = tryGetDiffPairType(builder, paramType); if (diffPairType) { - if (!as<IRPtrTypeBase>(diffPairType)) + if (!as<IRPtrTypeBase>(diffPairType) && !as<IRDifferentialPtrPairType>(diffPairType)) return builder->getInOutType(diffPairType); return diffPairType; } @@ -961,7 +984,7 @@ namespace Slang // Initialize the var with input diff param at start. // Note that we insert the store in the primal block so it won't get transposed. auto storeInst = nextBlockBuilder.emitStore(tempVar, diffParam); - nextBlockBuilder.markInstAsDifferential(storeInst, diffPairType); + nextBlockBuilder.markInstAsDifferential(storeInst, primalType); // Since this store inst is specific to propagate function, we track it in a // set so we can remove it when we generate the primal func. result.propagateFuncSpecificPrimalInsts.add(storeInst); diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 1fa76c730..2141837b5 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -174,45 +174,54 @@ IRInst* AutoDiffTranscriberBase::maybeCloneForPrimalInst(IRBuilder* builder, IRI IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); -IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType) +IRInst* AutoDiffTranscriberBase::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind) { - return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType); + if (kind == DiffConformanceKind::Any) + { + if (auto valueWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Value)) + return valueWitness; + if (auto ptrWitness = differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, DiffConformanceKind::Ptr)) + return ptrWitness; + } + else + { + return differentiableTypeConformanceContext.tryGetDifferentiableWitness(builder, originalType, kind); + } + return nullptr; } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) { - return builder->getDifferentialPairType( - (IRType*)primalType, - witness); + auto conformanceType = differentiableTypeConformanceContext.getConformanceTypeFromWitness(witness); + if (autoDiffSharedContext->isInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiableInterfaceType) + { + return builder->getDifferentialPairType((IRType*)primalType, witness); + } + else if (autoDiffSharedContext->isPtrInterfaceAvailable && + conformanceType == autoDiffSharedContext->differentiablePtrInterfaceType) + { + return builder->getDifferentialPtrPairType((IRType*)primalType, witness); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } } IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRInst* originalType) { - auto primalType = lookupPrimalInst(builder, originalType, nullptr); + auto primalType = lookupPrimalInst(builder, originalType, originalType); SLANG_RELEASE_ASSERT(primalType); IRInst* witness = nullptr; - if (auto lookup = as<IRLookupWitnessMethod>(primalType)) - { - if (lookup->getRequirementKey() == autoDiffSharedContext->differentialAssocTypeStructKey) - { - witness = builder->emitLookupInterfaceMethodInst( - lookup->getWitnessTable()->getDataType(), - lookup->getWitnessTable(), - autoDiffSharedContext->differentialAssocTypeWitnessStructKey); - } - } - - // Obtain the witness that primalType conforms to IDifferentiable. + + // Obtain the witness that primalType conforms to IDifferentiable/IDifferentiablePtrType if (!witness) - witness = tryGetDifferentiableWitness(builder, originalType); + witness = tryGetDifferentiableWitness(builder, primalType, DiffConformanceKind::Any); SLANG_RELEASE_ASSERT(witness); - auto pairType = builder->getDifferentialPairType( - (IRType*)primalType, - witness); - - return pairType; + return getOrCreateDiffPairType(builder, primalType, witness); } IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) @@ -223,8 +232,10 @@ IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* o // Special-case for differentiable existential types. if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType)) { - if (differentiableTypeConformanceContext.lookUpConformanceForType(origType)) + if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Value)) return autoDiffSharedContext->differentiableInterfaceType; + else if (differentiableTypeConformanceContext.lookUpConformanceForType(origType, DiffConformanceKind::Ptr)) + return autoDiffSharedContext->differentiablePtrInterfaceType; else return nullptr; } @@ -278,8 +289,9 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } case kIROp_DifferentialPairType: + case kIROp_DifferentialPtrPairType: { - auto primalPairType = as<IRDifferentialPairType>(primalType); + auto primalPairType = as<IRDifferentialPairTypeBase>(primalType); return getOrCreateDiffPairType( builder, differentiableTypeConformanceContext.getDiffTypeFromPairType(builder, primalPairType), @@ -445,8 +457,24 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType())); if (!interfaceType) return nullptr; - List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath( + + List<IRInterfaceRequirementEntry*> lookupKeyPath; + IRStructKey* diffStructKey = nullptr; + + List<IRInterfaceRequirementEntry*> lookupPathValueType = differentiableTypeConformanceContext.findInterfaceLookupPath( autoDiffSharedContext->differentiableInterfaceType, interfaceType); + if (lookupPathValueType.getCount() > 0) + { + lookupKeyPath = lookupPathValueType; + diffStructKey = autoDiffSharedContext->differentialAssocTypeStructKey; + } + else + { + // Try IDifferentiablePtrType + lookupKeyPath = differentiableTypeConformanceContext.findInterfaceLookupPath( + autoDiffSharedContext->differentiablePtrInterfaceType, interfaceType); + diffStructKey = autoDiffSharedContext->differentialAssocRefTypeStructKey; + } if (lookupKeyPath.getCount()) { @@ -456,7 +484,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* { outWitnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), outWitnessTable, node->getRequirementKey()); } - auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, autoDiffSharedContext->differentialAssocTypeStructKey); + auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), outWitnessTable, diffStructKey); return (IRType*)diffType; } return nullptr; @@ -561,10 +589,31 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui return InstPair(primal, diffWitness); } + else if (returnWitnessType->getConformanceType() == autoDiffSharedContext->differentiablePtrInterfaceType) + { + auto primalDiffType = builder->emitLookupInterfaceMethodInst( + builder->getTypeKind(), + primal, + autoDiffSharedContext->differentialAssocRefTypeStructKey); + auto diffWitness = builder->emitLookupInterfaceMethodInst( + (IRType*)primalDiffType, + primal, + autoDiffSharedContext->differentialAssocRefTypeWitnessStructKey); + + // Mark both as primal since we're working with types + // (which don't need transposing) + // + builder->markInstAsPrimal(primalDiffType); + builder->markInstAsPrimal(diffWitness); + + return InstPair(primal, diffWitness); + } } + auto decor = lookupInst->getRequirementKey()->findDecorationImpl( getInterfaceRequirementDerivativeDecorationOp()); + if (!decor) { return InstPair(primal, nullptr); @@ -589,6 +638,10 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType( { originalType = (IRType*)unwrapAttributedType(originalType); auto primalType = (IRType*)lookupPrimalInst(builder, originalType); + + // Can't generate zero for differentiable ptr types. Should never hit this case. + SLANG_ASSERT(!differentiableTypeConformanceContext.isDifferentiablePtrType(originalType)); + if (auto diffType = differentiateType(builder, originalType)) { IRInst* diffWitnessTable = nullptr; @@ -985,7 +1038,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst && !as<IRConstant>(pair.differential)) { auto primalType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsDifferential(pair.differential, primalType); + markDiffTypeInst(builder, pair.differential, primalType); } } else @@ -997,7 +1050,7 @@ IRInst* AutoDiffTranscriberBase::transcribe(IRBuilder* builder, IRInst* origInst if (as<IRType>(pair.differential)) break; auto mixedType = (IRType*)(pair.primal->getDataType()); - builder->markInstAsMixedDifferential(pair.primal, mixedType); + markDiffPairTypeInst(builder, pair.primal, mixedType); } } @@ -1076,4 +1129,64 @@ InstPair AutoDiffTranscriberBase::transcribeInst(IRBuilder* builder, IRInst* ori return result; } + +void AutoDiffTranscriberBase::markDiffTypeInst(IRBuilder* builder, IRInst* diffInst, IRType* primalType) +{ + // Ignore module-level insts. + if (as<IRModuleInst>(diffInst->getParent())) + return; + + // Also ignore generic-container-level insts. + if (as<IRBlock>(diffInst->getParent()) && + as<IRGeneric>(diffInst->getParent()->getParent())) + return; + + // TODO: This logic is a bit of a hack. We need to determine if the type is + // relevant to ptr-type computation or not, or more complex applications + // that use dynamic dispatch + ptr types will fail. + // + if (as<IRType>(diffInst)) + { + builder->markInstAsDifferential(diffInst, nullptr); + return; + } + + SLANG_ASSERT(diffInst); + SLANG_ASSERT(primalType); + + if (differentiableTypeConformanceContext.isDifferentiableValueType(primalType)) + { + builder->markInstAsDifferential(diffInst, primalType); + } + else if (differentiableTypeConformanceContext.isDifferentiablePtrType(primalType)) + { + builder->markInstAsPrimal(diffInst); + } + else + { + // Stop-gap solution to go with differential inst for now. + builder->markInstAsDifferential(diffInst, primalType); + } +} + +void AutoDiffTranscriberBase::markDiffPairTypeInst(IRBuilder* builder, IRInst* diffPairInst, IRType* pairType) +{ + SLANG_ASSERT(diffPairInst); + SLANG_ASSERT(pairType); + SLANG_ASSERT(as<IRDifferentialPairTypeBase>(pairType)); + + if (as<IRDifferentialPairType>(pairType)) + { + builder->markInstAsMixedDifferential(diffPairInst, pairType); + } + else if (as<IRDifferentialPtrPairType>(pairType)) + { + builder->markInstAsPrimal(diffPairInst); + } + else + { + SLANG_UNEXPECTED("unexpected differentiable type"); + } +} + } diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index f7f2dd6f2..9f3cfe56f 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -91,7 +91,7 @@ struct AutoDiffTranscriberBase void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); - IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); @@ -152,6 +152,10 @@ struct AutoDiffTranscriberBase virtual InstPair transcribeInstImpl(IRBuilder* builder, IRInst* origInst) = 0; virtual IROp getInterfaceRequirementDerivativeDecorationOp() = 0; + + void markDiffTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType); + + void markDiffPairTypeInst(IRBuilder* builder, IRInst* inst, IRType* primalType); }; } diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 1f8c3052e..8669df5a4 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -2116,7 +2116,8 @@ struct DiffTransposePass // If we reach this point, revValue must be a differentiable type. auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness( builder, - primalType); + primalType, + DiffConformanceKind::Value); SLANG_ASSERT(revTypeWitness); auto baseExistential = fwdInst->getOperand(0); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 0953c535a..507a2bf92 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -144,7 +144,10 @@ struct ExtractPrimalFuncContext } auto structField = genTypeBuilder.createStructField(structType, structKey, (IRType*)fieldType); - if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness(&genTypeBuilder, (IRType*)fieldType)) + if (auto witness = backwardPrimalTranscriber->tryGetDifferentiableWitness( + &genTypeBuilder, + (IRType*)fieldType, + DiffConformanceKind::Value)) { genTypeBuilder.addIntermediateContextFieldDifferentialTypeDecoration(structField, witness); } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 07a6a76fb..94a605a68 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -25,7 +25,7 @@ bool isBackwardDifferentiableFunc(IRInst* func) return false; } -IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey) +IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey, IRType* resultType = nullptr) { if (auto witnessTable = as<IRWitnessTable>(witness)) { @@ -53,15 +53,16 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK } else { + SLANG_ASSERT(resultType); return builder->emitLookupInterfaceMethodInst( - builder->getTypeKind(), + resultType, witness, requirementKey); } return nullptr; } -static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) +static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witness = type->getWitness(); SLANG_RELEASE_ASSERT(witness); @@ -70,16 +71,48 @@ static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRB if (as<IRInterfaceType>(type->getValueType()) || as<IRAssociatedType>(type->getValueType())) { // The differential type is the IDifferentiable interface type. - return sharedContext->differentiableInterfaceType; + if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type)) + return sharedContext->differentiableInterfaceType; + else if (as<IRDifferentialPtrPairType>(type)) + return sharedContext->differentiablePtrInterfaceType; + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } - return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); + if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocTypeStructKey, + builder->getTypeKind()); + else if (as<IRDifferentialPtrPairType>(type)) + return _lookupWitness( + builder, + witness, + sharedContext->differentialAssocRefTypeStructKey, + builder->getTypeKind()); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); + + if (as<IRDifferentialPairType>(type) || as<IRDifferentialPairUserCodeType>(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType); + else if (as<IRDifferentialPtrPairType>(type)) + return _lookupWitness( + builder, + witnessTable, + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + else + SLANG_UNEXPECTED("Unexpected differential pair type"); } bool isNoDiffType(IRType* paramType) @@ -320,6 +353,24 @@ IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( return result; } +IRInterfaceType* findDifferentiableRefInterface(IRModuleInst* moduleInst) +{ + for (auto inst : moduleInst->getGlobalInsts()) + { + if (auto interfaceType = as<IRInterfaceType>(inst)) + { + if (auto decor = interfaceType->findDecoration<IRNameHintDecoration>()) + { + if (decor->getName() == "IDifferentiablePtrType") + { + return interfaceType; + } + } + } + } + return nullptr; +} + AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst* inModuleInst) : moduleInst(inModuleInst), targetProgram(target) { @@ -328,14 +379,27 @@ AutoDiffSharedContext::AutoDiffSharedContext(TargetProgram* target, IRModuleInst { differentialAssocTypeStructKey = findDifferentialTypeStructKey(); differentialAssocTypeWitnessStructKey = findDifferentialTypeWitnessStructKey(); + differentialAssocTypeWitnessTableType = findDifferentialTypeWitnessTableType(); zeroMethodStructKey = findZeroMethodStructKey(); + zeroMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementVal()); addMethodStructKey = findAddMethodStructKey(); + addMethodType = cast<IRFuncType>(getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementVal()); mulMethodStructKey = findMulMethodStructKey(); nullDifferentialStructType = findNullDifferentialStructType(); nullDifferentialWitness = findNullDifferentialWitness(); - if (differentialAssocTypeStructKey) - isInterfaceAvailable = true; + isInterfaceAvailable = true; + } + + differentiablePtrInterfaceType = as<IRInterfaceType>(findDifferentiableRefInterface(inModuleInst)); + + if (differentiablePtrInterfaceType) + { + differentialAssocRefTypeStructKey = findDifferentialPtrTypeStructKey(); + differentialAssocRefTypeWitnessStructKey = findDifferentialPtrTypeWitnessStructKey(); + differentialAssocRefTypeWitnessTableType = findDifferentialPtrTypeWitnessTableType(); + + isPtrInterfaceAvailable = true; } } @@ -404,14 +468,14 @@ IRInst* AutoDiffSharedContext::findNullDifferentialWitness() } -IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) +IRInterfaceRequirementEntry* AutoDiffSharedContext::getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index) { - if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) + if (as<IRModuleInst>(moduleInst) && interface) { // Assume for now that IDifferentiable has exactly five fields. - SLANG_ASSERT(differentiableInterfaceType->getOperandCount() == 5); - if (auto entry = as<IRInterfaceRequirementEntry>(differentiableInterfaceType->getOperand(index))) - return as<IRStructKey>(entry->getRequirementKey()); + // SLANG_ASSERT(interface->getOperandCount() == 5); + if (auto entry = as<IRInterfaceRequirementEntry>(interface->getOperand(index))) + return entry; else { SLANG_UNEXPECTED("IDifferentiable interface entry unexpected type"); @@ -421,6 +485,50 @@ IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt inde return nullptr; } +// Extracts conformance interface from a witness inst while accounting for some +// quirks in the type system around interfaces that conform to other interfaces. +// +IRInterfaceType* DifferentiableTypeConformanceContext::getConformanceTypeFromWitness(IRInst* witness) +{ + IRInterfaceType* diffInterfaceType = nullptr; + if (auto witnessTableType = as<IRWitnessTableType>(witness->getDataType())) + { + diffInterfaceType = cast<IRInterfaceType>(witnessTableType->getConformanceType()); + } + else if (auto structKey = as<IRStructKey>(witness)) + { + // We currently assume that a struct key is used uniquely for a single interface-requirement-entry. + // Find that entry + for (IRUse* use = structKey->firstUse; use; use = use->nextUse) + { + if (auto entry = as<IRInterfaceRequirementEntry>(use->getUser())) + { + auto innerWitnessTableType = cast<IRWitnessTableType>(entry->getRequirementVal()); + diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType()); + break; + } + } + } + else if (auto interfaceRequirementEntry = as<IRInterfaceRequirementEntry>(witness)) + { + auto innerWitnessTableType = cast<IRWitnessTableType>(interfaceRequirementEntry->getRequirementVal()); + diffInterfaceType = cast<IRInterfaceType>(innerWitnessTableType->getConformanceType()); + } + else if (auto tupleType = as<IRTupleType>(witness->getDataType())) + { + SLANG_ASSERT(tupleType->getOperandCount() >= 1); + auto operand = tupleType->getOperand(0); + auto innerWitnessTableType = cast<IRWitnessTableType>(operand); + return cast<IRInterfaceType>(innerWitnessTableType->getConformanceType()); + } + else + { + SLANG_UNEXPECTED("Unexpected witness type"); + } + + return diffInterfaceType; +} + void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; @@ -434,7 +542,13 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { if (auto item = as<IRDifferentiableTypeDictionaryItem>(child)) { - auto existingItem = differentiableWitnessDictionary.tryGetValue(item->getConcreteType()); + IRInterfaceType* diffInterfaceType = getConformanceTypeFromWitness(item->getWitness()); + + SLANG_ASSERT( + diffInterfaceType == sharedContext->differentiableInterfaceType + || diffInterfaceType == sharedContext->differentiablePtrInterfaceType); + + auto existingItem = differentiableTypeWitnessDictionary.tryGetValue(item->getConcreteType()); if (existingItem) { *existingItem = item->getWitness(); @@ -458,20 +572,26 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { auto element = concreteType->getOperand(i); auto elementWitness = witnessPack->getOperand(i); - differentiableWitnessDictionary.addIfNotExists( - (IRType*)element, - _lookupWitness(&subBuilder, elementWitness, sharedContext->differentialAssocTypeStructKey)); + + if (diffInterfaceType == sharedContext->differentiableInterfaceType) + addTypeToDictionary( + (IRType*)element, + elementWitness); + else if (diffInterfaceType == sharedContext->differentiablePtrInterfaceType) + addTypeToDictionary( + (IRType*)element, + elementWitness); } return; } } - differentiableWitnessDictionary.add((IRType*)item->getConcreteType(), item->getWitness()); + addTypeToDictionary((IRType*)item->getConcreteType(), item->getWitness()); if (!as<IRInterfaceType>(item->getConcreteType())) { - differentiableWitnessDictionary.addIfNotExists( - (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey), + addTypeToDictionary( + (IRType*)_lookupWitness(&subBuilder, item->getWitness(), sharedContext->differentialAssocTypeStructKey, subBuilder.getTypeKind()), item->getWitness()); } @@ -480,29 +600,55 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) // For differential pair types, register the differential type as well. IRBuilder builder(diffPairType); builder.setInsertAfter(diffPairType->getWitness()); - auto diffType = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey); - auto diffWitness = _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeWitnessStructKey); - if (diffType && diffWitness) - { - differentiableWitnessDictionary.addIfNotExists((IRType*)diffType, diffWitness); - } + + // TODO(sai): lot of this logic is duplicated. need to refactor. + auto diffType = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocTypeStructKey, builder.getTypeKind()) : + _lookupWitness(&builder, diffPairType->getWitness(), sharedContext->differentialAssocRefTypeStructKey, builder.getTypeKind()); + auto diffWitness = (diffInterfaceType == sharedContext->differentiableInterfaceType) ? + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocTypeWitnessStructKey, + sharedContext->differentialAssocTypeWitnessTableType) : + _lookupWitness( + &builder, + diffPairType->getWitness(), + sharedContext->differentialAssocRefTypeWitnessStructKey, + sharedContext->differentialAssocRefTypeWitnessTableType); + + addTypeToDictionary((IRType*)diffType, diffWitness); } } } } } -IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type) +IRInst* DifferentiableTypeConformanceContext::lookUpConformanceForType(IRInst* type, DiffConformanceKind kind) { IRInst* foundResult = nullptr; - differentiableWitnessDictionary.tryGetValue(type, foundResult); - return foundResult; + differentiableTypeWitnessDictionary.tryGetValue(type, foundResult); + if (!foundResult) + return nullptr; + + if (kind == DiffConformanceKind::Any) + return foundResult; + + if (auto baseType = getConformanceTypeFromWitness(foundResult)) + { + if (baseType == sharedContext->differentiableInterfaceType && kind == DiffConformanceKind::Value) + return foundResult; + else if (baseType == sharedContext->differentiablePtrInterfaceType && kind == DiffConformanceKind::Ptr) + return foundResult; + } + + return nullptr; } -IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key) +IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType) { - if (auto conformance = tryGetDifferentiableWitness(builder, origType)) - return _lookupWitness(builder, conformance, key); + if (auto conformance = tryGetDifferentiableWitness(builder, origType, DiffConformanceKind::Any)) + return _lookupWitness(builder, conformance, key, resultType); return nullptr; } @@ -514,7 +660,7 @@ IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairTyp IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { - return _getDiffTypeFromPairType(sharedContext, builder, type); + return this->differentiateType(builder, type->getValueType()); } IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) @@ -525,20 +671,34 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); } IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) { auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey); + return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey, sharedContext->addMethodType); +} + +void DifferentiableTypeConformanceContext::addTypeToDictionary(IRType* type, IRInst* witness) +{ + auto conformanceType = getConformanceTypeFromWitness(witness); + + if (!sharedContext->isInterfaceAvailable && !sharedContext->isPtrInterfaceAvailable) + return; + + SLANG_ASSERT( + conformanceType == sharedContext->differentiableInterfaceType || + conformanceType == sharedContext->differentiablePtrInterfaceType); + + differentiableTypeWitnessDictionary.addIfNotExists(type, witness); } IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable) { SLANG_RELEASE_ASSERT(interfaceType); - List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + List<IRInterfaceRequirementEntry*> lookupKeyPath = findInterfaceLookupPath( sharedContext->differentiableInterfaceType, interfaceType); IRInst* differentialTypeWitness = witnessTable; @@ -549,6 +709,7 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface { differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); // Lookup insts are always primal values. + builder->markInstAsPrimal(differentialTypeWitness); } return differentialTypeWitness; @@ -557,10 +718,10 @@ IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterface return nullptr; } -// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. -static bool _findDifferentiableInterfaceLookupPathImpl( +// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `supType`. +static bool _findInterfaceLookupPathImpl( HashSet<IRInst*>& processedTypes, - IRInterfaceType* idiffType, + IRInterfaceType* supType, IRInterfaceType* type, List<IRInterfaceRequirementEntry*>& currentPath) { @@ -576,13 +737,13 @@ static bool _findDifferentiableInterfaceLookupPathImpl( if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal())) { currentPath.add(entry); - if (wt->getConformanceType() == idiffType) + if (wt->getConformanceType() == supType) { return true; } else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType())) { - if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) + if (_findInterfaceLookupPathImpl(processedTypes, supType, subInterfaceType, currentPath)) return true; } currentPath.removeLast(); @@ -591,11 +752,11 @@ static bool _findDifferentiableInterfaceLookupPathImpl( return false; } -List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type) +List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findInterfaceLookupPath(IRInterfaceType *supType, IRInterfaceType *type) { List<IRInterfaceRequirementEntry*> currentPath; HashSet<IRInst*> processedTypes; - _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); + _findInterfaceLookupPathImpl(processedTypes, supType, type, currentPath); return currentPath; } @@ -722,7 +883,7 @@ void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { if (auto pairType = as<IRDifferentialPairTypeBase>(globalInst)) { - differentiableWitnessDictionary.addIfNotExists(pairType->getValueType(), pairType->getWitness()); + addTypeToDictionary(pairType->getValueType(), pairType->getWitness()); } } } @@ -762,9 +923,8 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build case kIROp_DifferentialPairType: { auto primalPairType = as<IRDifferentialPairType>(primalType); - return getOrCreateDiffPairType( - builder, - getDiffTypeFromPairType(builder, primalPairType), + return builder->getDifferentialPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), getDiffTypeWitnessFromPairType(builder, primalPairType)); } @@ -776,6 +936,14 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build getDiffTypeWitnessFromPairType(builder, primalPairType)); } + case kIROp_DifferentialPtrPairType: + { + auto primalPairType = as<IRDifferentialPtrPairType>(primalType); + return builder->getDifferentialPtrPairType( + (IRType*)getDiffTypeFromPairType(builder, primalPairType), + getDiffTypeWitnessFromPairType(builder, primalPairType)); + } + case kIROp_FuncType: { SLANG_UNIMPLEMENTED_X("Impl"); @@ -817,12 +985,12 @@ IRType* DifferentiableTypeConformanceContext::differentiateType(IRBuilder* build } } -IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType) +IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuilder* builder, IRInst* primalType, DiffConformanceKind kind) { if (isNoDiffType((IRType*)primalType)) return nullptr; - - IRInst* witness = lookUpConformanceForType((IRType*)primalType); + + IRInst* witness = lookUpConformanceForType((IRType*)primalType, kind); if (witness) { SLANG_RELEASE_ASSERT(witness || as<IRArrayType>(primalType)); @@ -834,31 +1002,60 @@ IRInst* DifferentiableTypeConformanceContext::tryGetDifferentiableWitness(IRBuil witness = nullptr; } - if (!witness) + if (witness) + return witness; + + // If a witness is not already mapped, build one if possible. + SLANG_RELEASE_ASSERT(primalType); + if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType)) { - SLANG_RELEASE_ASSERT(primalType); - if (auto primalPairType = as<IRDifferentialPairTypeBase>(primalType)) - { - witness = getOrCreateDifferentiablePairWitness(builder, primalPairType); - } - else if (auto arrayType = as<IRArrayType>(primalType)) - { - witness = getArrayWitness(builder, arrayType); - } - else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) - { - witness = getExtractExistensialTypeWitness(builder, extractExistential); - } - else if (auto typePack = as<IRTypePack>(primalType)) + witness = buildDifferentiablePairWitness(builder, primalPairType, kind); + } + else if (auto arrayType = as<IRArrayType>(primalType)) + { + witness = buildArrayWitness(builder, arrayType, kind); + } + else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) + { + witness = buildExtractExistensialTypeWitness(builder, extractExistential, kind); + } + else if (auto typePack = as<IRTypePack>(primalType)) + { + witness = buildTupleWitness(builder, typePack, kind); + } + else if (auto tupleType = as<IRTupleType>(primalType)) + { + witness = buildTupleWitness(builder, tupleType, kind); + } + else if (auto lookup = as<IRLookupWitnessMethod>(primalType)) + { + // For types that are lookups from a table, we can simply lookup the witness from the same table + if (lookup->getRequirementKey() == sharedContext->differentialAssocTypeStructKey) { - witness = getTupleWitness(builder, typePack); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocTypeWitnessStructKey); } - else if (auto tupleType = as<IRTupleType>(primalType)) + + if (lookup->getRequirementKey() == sharedContext->differentialAssocRefTypeStructKey) { - witness = getTupleWitness(builder, tupleType); + witness = builder->emitLookupInterfaceMethodInst( + lookup->getWitnessTable()->getDataType(), + lookup->getWitnessTable(), + sharedContext->differentialAssocRefTypeWitnessStructKey); } } - return witness; + + // If we created a witness, register it. + if (witness) + { + addTypeToDictionary((IRType*)primalType, witness); + return witness; + } + + // Failed. Type is either non-differentiable, or unhandled. + return nullptr; } IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness) @@ -868,77 +1065,97 @@ IRType* DifferentiableTypeConformanceContext::getOrCreateDiffPairType(IRBuilder* witness); } -IRInst* DifferentiableTypeConformanceContext::getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType) +IRInst* DifferentiableTypeConformanceContext::buildDifferentiablePairWitness( + IRBuilder* builder, + IRDifferentialPairTypeBase* pairType, + DiffConformanceKind target) { - // Differentiate the pair type to get it's differential (which is itself a pair) - auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); - - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(this->sharedContext->differentiableInterfaceType, (IRType*)pairType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false; - - // Fill in differential method implementations. - auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType(); - auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness(); - - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); - b.emitBlock(); - auto p0 = b.emitParam(diffDiffPairType); - auto p1 = b.emitParam(diffDiffPairType); - - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - IRInst* argsPrimal[2] = { - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), - isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; - auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); - IRInst* argsDiff[2] = { - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), - isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; - auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) - : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); - b.emitReturn(retVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); - b.emitBlock(); - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = - isUserCodeType - ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) - : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); - b.emitReturn(retVal); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + bool isUserCodeType = as<IRDifferentialPairUserCodeType>(pairType) ? true : false; + + // Fill in differential method implementations. + auto elementType = as<IRDifferentialPairTypeBase>(pairType)->getValueType(); + auto innerWitness = as<IRDifferentialPairTypeBase>(pairType)->getWitness(); + + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffDiffPairType, diffDiffPairType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffDiffPairType)); + b.emitBlock(); + auto p0 = b.emitParam(diffDiffPairType); + auto p1 = b.emitParam(diffDiffPairType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + IRInst* argsPrimal[2] = { + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p0) : b.emitDifferentialPairGetPrimal(p0), + isUserCodeType ? b.emitDifferentialPairGetPrimalUserCode(p1) : b.emitDifferentialPairGetPrimal(p1) }; + auto primalPart = b.emitCallInst(elementType, innerAdd, 2, argsPrimal); + IRInst* argsDiff[2] = { + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p0) : b.emitDifferentialPairGetDifferential(elementType, p0), + isUserCodeType ? b.emitDifferentialPairGetDifferentialUserCode(elementType, p1) : b.emitDifferentialPairGetDifferential(elementType, p1)}; + auto diffPart = b.emitCallInst(elementType, innerAdd, 2, argsDiff); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, primalPart, diffPart) + : b.emitMakeDifferentialPair(diffDiffPairType, primalPart, diffPart); + b.emitReturn(retVal); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffDiffPairType)); + b.emitBlock(); + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = + isUserCodeType + ? b.emitMakeDifferentialPairUserCode(diffDiffPairType, zeroVal, zeroVal) + : b.emitMakeDifferentialPair(diffDiffPairType, zeroVal, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) + { + // Differentiate the pair type to get it's differential (which is itself a pair) + auto diffDiffPairType = (IRType*)differentiateType(builder, (IRType*)pairType); + + table = builder->createWitnessTable( + sharedContext->differentiablePtrInterfaceType, + (IRType*)pairType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffDiffPairType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } - - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)pairType] = table; return table; } -IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder, IRArrayType* arrayType) +IRInst* DifferentiableTypeConformanceContext::buildArrayWitness( + IRBuilder* builder, + IRArrayType* arrayType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffArrayType = (IRType*)differentiateType(builder, (IRType*)arrayType); @@ -946,70 +1163,89 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder if (!diffArrayType) return nullptr; - auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType()); + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)arrayType)); + auto innerWitness = tryGetDifferentiableWitness(builder, as<IRArrayTypeBase>(arrayType)->getElementType(), DiffConformanceKind::Value); - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)arrayType); - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); + auto elementType = as<IRArrayTypeBase>(diffArrayType)->getElementType(); - // Fill in differential method implementations. + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffArrayType, diffArrayType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); + b.emitBlock(); + auto p0 = b.emitParam(diffArrayType); + auto p1 = b.emitParam(diffArrayType); + + // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto resultVar = b.emitVar(diffArrayType); + IRBlock* loopBodyBlock = nullptr; + IRBlock* loopBreakBlock = nullptr; + auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); + b.setInsertBefore(loopBodyBlock->getTerminator()); + + IRInst* args[2] = { + b.emitElementExtract(p0, loopCounter), + b.emitElementExtract(p1, loopCounter) }; + auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); + auto addr = b.emitElementAddress(resultVar, loopCounter); + b.emitStore(addr, elementResult); + b.setInsertInto(loopBreakBlock); + b.emitReturn(b.emitLoad(resultVar)); + } + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(zeroMethod); + zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); + b.emitBlock(); + + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); + auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); + b.emitReturn(retVal); + } + } + else if (target == DiffConformanceKind::Ptr) { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffArrayType, diffArrayType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffArrayType)); - b.emitBlock(); - auto p0 = b.emitParam(diffArrayType); - auto p1 = b.emitParam(diffArrayType); + SLANG_ASSERT(isDifferentiablePtrType((IRType*)arrayType)); - // Since we are already dealing with a DiffPair<T>.Differnetial type, we know that value type == diff type. - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto resultVar = b.emitVar(diffArrayType); - IRBlock* loopBodyBlock = nullptr; - IRBlock* loopBreakBlock = nullptr; - auto loopCounter = emitLoopBlocks(&b, b.getIntValue(b.getIntType(), 0), as<IRArrayTypeBase>(diffArrayType)->getElementCount(), loopBodyBlock, loopBreakBlock); - b.setInsertBefore(loopBodyBlock->getTerminator()); + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)arrayType); - IRInst* args[2] = { - b.emitElementExtract(p0, loopCounter), - b.emitElementExtract(p1, loopCounter) }; - auto elementResult = b.emitCallInst(elementType, innerAdd, 2, args); - auto addr = b.emitElementAddress(resultVar, loopCounter); - b.emitStore(addr, elementResult); - b.setInsertInto(loopBreakBlock); - b.emitReturn(b.emitLoad(resultVar)); + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffArrayType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); } + else { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(zeroMethod); - zeroMethod->setFullType(b.getFuncType(0, nullptr, diffArrayType)); - b.emitBlock(); - - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - auto zeroVal = b.emitCallInst(elementType, innerZero, 0, nullptr); - auto retVal = b.emitMakeArrayFromElement(diffArrayType, zeroVal); - b.emitReturn(retVal); + SLANG_UNEXPECTED("Invalid conformance kind for synthesis"); } - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)arrayType] = table; - return table; } -IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder, IRInst* inTupleType) +IRInst* DifferentiableTypeConformanceContext::buildTupleWitness( + IRBuilder* builder, + IRInst* inTupleType, + DiffConformanceKind target) { // Differentiate the pair type to get it's differential (which is itself a pair) auto diffTupleType = (IRType*)differentiateType(builder, (IRType*)inTupleType); @@ -1017,100 +1253,116 @@ IRInst* DifferentiableTypeConformanceContext::getTupleWitness(IRBuilder* builder if (!diffTupleType) return nullptr; - auto addMethod = builder->createFunc(); - auto zeroMethod = builder->createFunc(); - - auto table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); - - // And place it in the synthesized witness table. - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); - builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); - builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); - builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); - - // Fill in differential method implementations. - { - // Add method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - IRType* paramTypes[2] = { diffTupleType, diffTupleType }; - addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); - b.emitBlock(); - auto p0 = b.emitParam(diffTupleType); - auto p1 = b.emitParam(diffTupleType); - List<IRInst*> results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + IRWitnessTable* table = nullptr; + if (target == DiffConformanceKind::Value) + { + SLANG_ASSERT(isDifferentiableValueType((IRType*)inTupleType)); + + auto addMethod = builder->createFunc(); + auto zeroMethod = builder->createFunc(); + + table = builder->createWitnessTable(sharedContext->differentiableInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocTypeWitnessStructKey, table); + builder->createWitnessTableEntry(table, sharedContext->addMethodStructKey, addMethod); + builder->createWitnessTableEntry(table, sharedContext->zeroMethodStructKey, zeroMethod); + + // Fill in differential method implementations. + { + // Add method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + IRType* paramTypes[2] = { diffTupleType, diffTupleType }; + addMethod->setFullType(b.getFuncType(2, paramTypes, diffTupleType)); + b.emitBlock(); + auto p0 = b.emitParam(diffTupleType); + auto p1 = b.emitParam(diffTupleType); + List<IRInst*> results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey, sharedContext->addMethodType); + auto iVal = b.getIntValue(b.getIntType(), i); + IRInst* args[2] = { + b.emitGetTupleElement(diffElementType, p0, iVal), + b.emitGetTupleElement(diffElementType, p1, iVal) }; + elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerAdd = _lookupWitness(&b, innerWitness, sharedContext->addMethodStructKey); - auto iVal = b.getIntValue(b.getIntType(), i); - IRInst* args[2] = { - b.emitGetTupleElement(diffElementType, p0, iVal), - b.emitGetTupleElement(diffElementType, p1, iVal) }; - elementResult = b.emitCallInst(diffElementType, innerAdd, 2, args); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); - } - { - // Zero method. - IRBuilder b = *builder; - b.setInsertInto(addMethod); - b.addBackwardDifferentiableDecoration(addMethod); - addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); - b.emitBlock(); - List<IRInst*> results; - for (UInt i = 0; i < inTupleType->getOperandCount(); i++) - { - auto elementType = inTupleType->getOperand(i); - auto diffElementType = (IRType*)diffTupleType->getOperand(i); - auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType); - IRInst* elementResult = nullptr; - if (!innerWitness) + { + // Zero method. + IRBuilder b = *builder; + b.setInsertInto(addMethod); + b.addBackwardDifferentiableDecoration(addMethod); + addMethod->setFullType(b.getFuncType(0, nullptr, diffTupleType)); + b.emitBlock(); + List<IRInst*> results; + for (UInt i = 0; i < inTupleType->getOperandCount(); i++) { - elementResult = b.getVoidValue(); + auto elementType = inTupleType->getOperand(i); + auto diffElementType = (IRType*)diffTupleType->getOperand(i); + auto innerWitness = tryGetDifferentiableWitness(&b, (IRType*)elementType, DiffConformanceKind::Value); + IRInst* elementResult = nullptr; + if (!innerWitness) + { + elementResult = b.getVoidValue(); + } + else + { + auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); + elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); + } + results.add(elementResult); } + IRInst* resultVal = nullptr; + if (diffTupleType->getOp() == kIROp_TupleType) + resultVal = b.emitMakeTuple(diffTupleType, results); else - { - auto innerZero = _lookupWitness(&b, innerWitness, sharedContext->zeroMethodStructKey); - elementResult = b.emitCallInst(diffElementType, innerZero, 0, nullptr); - } - results.add(elementResult); + resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); + b.emitReturn(resultVal); } - IRInst* resultVal = nullptr; - if (diffTupleType->getOp() == kIROp_TupleType) - resultVal = b.emitMakeTuple(diffTupleType, results); - else - resultVal = b.emitMakeValuePack(diffTupleType, (UInt)results.getCount(), results.getBuffer()); - b.emitReturn(resultVal); } + else if (target == DiffConformanceKind::Ptr) + { + SLANG_ASSERT(isDifferentiablePtrType((IRType*)inTupleType)); - // Record this in the context for future lookups - differentiableWitnessDictionary[(IRType*)inTupleType] = table; + table = builder->createWitnessTable(sharedContext->differentiablePtrInterfaceType, (IRType*)inTupleType); + + // And place it in the synthesized witness table. + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeStructKey, diffTupleType); + builder->createWitnessTableEntry(table, sharedContext->differentialAssocRefTypeWitnessStructKey, table); + } return table; } -IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( +IRInst* DifferentiableTypeConformanceContext::buildExtractExistensialTypeWitness( IRBuilder* builder, - IRExtractExistentialType* extractExistentialType) + IRExtractExistentialType* extractExistentialType, + DiffConformanceKind target) { + SLANG_UNUSED(target); // logic is the same for both value and ptr + // Check that the type's base is differentiable if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType())) { @@ -1310,12 +1562,13 @@ bool isDifferentiableType(DifferentiableTypeConformanceContext& context, IRInst* if (context.isDifferentiableType((IRType*)typeInst)) return true; + // Look for equivalent types. - for (auto type : context.differentiableWitnessDictionary) + for (auto type : context.differentiableTypeWitnessDictionary) { if (isTypeEqual(type.key, (IRType*)typeInst)) { - context.differentiableWitnessDictionary[(IRType*)typeInst] = type.value; + context.differentiableTypeWitnessDictionary[(IRType*)typeInst] = type.value; return true; } } @@ -1672,7 +1925,7 @@ struct AutoDiffPass : public InstPassBase IRBuilder keyBuilder = builder; keyBuilder.setInsertBefore(maybeFindOuterGeneric(originalType)); auto diffKey = keyBuilder.createStructKey(); - auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey); + auto diffFieldType = _lookupWitness(&keyBuilder, diffFieldWitness, autodiffContext->differentialAssocTypeStructKey, builder.getTypeKind()); info.field = builder.createStructField(diffType, diffKey, (IRType*)diffFieldType); info.witness = diffFieldWitness; builder.addDecoration(field->getKey(), kIROp_DerivativeMemberDecoration, diffKey); @@ -1695,7 +1948,11 @@ struct AutoDiffPass : public InstPassBase List<IRInst*> fieldVals; for (auto info : diffFields) { - auto innerZeroMethod = _lookupWitness(&builder, info.witness, autodiffContext->zeroMethodStructKey); + auto innerZeroMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->zeroMethodStructKey, + autodiffContext->zeroMethodType); IRInst* val = builder.emitCallInst(info.field->getFieldType(), innerZeroMethod, 0, nullptr); fieldVals.add(val); } @@ -1719,7 +1976,11 @@ struct AutoDiffPass : public InstPassBase List<IRInst*> fieldVals; for (auto info : diffFields) { - auto innerAddMethod = _lookupWitness(&builder, info.witness, autodiffContext->addMethodStructKey); + auto innerAddMethod = _lookupWitness( + &builder, + info.witness, + autodiffContext->addMethodStructKey, + autodiffContext->addMethodType); IRInst* args[2] = { builder.emitFieldExtract(info.field->getFieldType(), param1, info.field->getKey()), builder.emitFieldExtract(info.field->getFieldType(), param2, info.field->getKey()), diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 812471fe3..ad2486aad 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -57,6 +57,14 @@ struct DiffTranscriberSet AutoDiffTranscriberBase* backwardTranscriber = nullptr; }; + +enum class DiffConformanceKind +{ + Any = 0, // Perform actions for any conformance (infer from context) + Ptr = 1, // Perform actions for IDifferentiablePtrType + Value = 2 // Perform actions for IDifferentiable +}; + struct AutoDiffSharedContext { TargetProgram* targetProgram = nullptr; @@ -78,6 +86,7 @@ struct AutoDiffSharedContext // The struct key for the witness that `Differential` associated type conforms to // `IDifferential`. IRStructKey* differentialAssocTypeWitnessStructKey = nullptr; + IRWitnessTableType* differentialAssocTypeWitnessTableType = nullptr; // The struct key for the 'zero()' associated type @@ -85,12 +94,14 @@ struct AutoDiffSharedContext // implementation of zero() for a given type. // IRStructKey* zeroMethodStructKey = nullptr; + IRFuncType* zeroMethodType = nullptr; // The struct key for the 'add()' associated type // defined inside IDifferential. We use this to lookup the // implementation of add() for a given type. // IRStructKey* addMethodStructKey = nullptr; + IRFuncType* addMethodType = nullptr; IRStructKey* mulMethodStructKey = nullptr; @@ -104,12 +115,27 @@ struct AutoDiffSharedContext // IRInst* nullDifferentialWitness = nullptr; + + // A reference to the builtin IDifferentiablePtrType interface type. + IRInterfaceType* differentiablePtrInterfaceType = nullptr; + + // The struct key for the 'Differential' associated type + // defined inside IDifferentialPtrType. We use this to lookup the differential + // type in the conformance table associated with the concrete type. + // + IRStructKey* differentialAssocRefTypeStructKey = nullptr; + + // The struct key for the witness that `Differential` associated type conforms to + // `IDifferentialPtrType`. + IRStructKey* differentialAssocRefTypeWitnessStructKey = nullptr; + IRWitnessTableType* differentialAssocRefTypeWitnessTableType = nullptr; // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. // Set to false to indicate that we are uninitialized. // bool isInterfaceAvailable = false; + bool isPtrInterfaceAvailable = false; List<FuncBodyTranscriptionTask> followUpFunctionsToTranscribe; @@ -127,38 +153,70 @@ private: IRStructKey* findDifferentialTypeStructKey() { - return getIDifferentiableStructKeyAtIndex(0); + return cast<IRStructKey>( + getInterfaceEntryAtIndex(differentiableInterfaceType, 0)->getRequirementKey()); } IRStructKey* findDifferentialTypeWitnessStructKey() { - return getIDifferentiableStructKeyAtIndex(1); + return cast<IRStructKey>( + getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementKey()); + } + + IRWitnessTableType* findDifferentialTypeWitnessTableType() + { + return cast<IRWitnessTableType>( + getInterfaceEntryAtIndex(differentiableInterfaceType, 1)->getRequirementVal()); } IRStructKey* findZeroMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(2); + return cast<IRStructKey>( + getInterfaceEntryAtIndex(differentiableInterfaceType, 2)->getRequirementKey()); } IRStructKey* findAddMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(3); + return cast<IRStructKey>( + getInterfaceEntryAtIndex(differentiableInterfaceType, 3)->getRequirementKey()); } IRStructKey* findMulMethodStructKey() { - return getIDifferentiableStructKeyAtIndex(4); + return cast<IRStructKey>( + getInterfaceEntryAtIndex(differentiableInterfaceType, 4)->getRequirementKey()); + } + + + IRStructKey* findDifferentialPtrTypeStructKey() + { + return cast<IRStructKey>( + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 0)->getRequirementKey()); } - IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); + IRStructKey* findDifferentialPtrTypeWitnessStructKey() + { + return cast<IRStructKey>( + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementKey()); + } + + IRWitnessTableType* findDifferentialPtrTypeWitnessTableType() + { + return cast<IRWitnessTableType>( + getInterfaceEntryAtIndex(differentiablePtrInterfaceType, 1)->getRequirementVal()); + } + + //IRStructKey* getIDifferentiableStructKeyAtIndex(UInt index); + IRInterfaceRequirementEntry* getInterfaceEntryAtIndex(IRInterfaceType* interface, UInt index); }; + struct DifferentiableTypeConformanceContext { AutoDiffSharedContext* sharedContext; IRGlobalValueWithCode* parentFunc = nullptr; - OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary; + OrderedDictionary<IRType*, IRInst*> differentiableTypeWitnessDictionary; IRFunc* existentialDAddFunc = nullptr; @@ -167,7 +225,7 @@ struct DifferentiableTypeConformanceContext { // Populate dictionary with null differential type. if (sharedContext->nullDifferentialStructType) - differentiableWitnessDictionary.add( + differentiableTypeWitnessDictionary.add( sharedContext->nullDifferentialStructType, sharedContext->nullDifferentialWitness); } @@ -179,21 +237,13 @@ struct DifferentiableTypeConformanceContext // Lookup a witness table for the concreteType. One should exist if concreteType // inherits (successfully) from IDifferentiable. // - IRInst* lookUpConformanceForType(IRInst* type); + IRInst* lookUpConformanceForType(IRInst* type, DiffConformanceKind kind); - IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key); + IRInst* lookUpInterfaceMethod(IRBuilder* builder, IRType* origType, IRStructKey* key, IRType* resultType = nullptr); IRType* differentiateType(IRBuilder* builder, IRInst* primalType); - IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType); - - IRInst* getOrCreateDifferentiablePairWitness(IRBuilder* builder, IRDifferentialPairTypeBase* pairType); - - IRInst* getArrayWitness(IRBuilder* builder, IRArrayType* pairType); - - IRInst* getTupleWitness(IRBuilder* builder, IRInst* tupleType); - - IRInst* getExtractExistensialTypeWitness(IRBuilder* builder, IRExtractExistentialType* extractExistentialType); + IRInst* tryGetDifferentiableWitness(IRBuilder* builder, IRInst* originalType, DiffConformanceKind kind); IRType* getOrCreateDiffPairType(IRBuilder* builder, IRInst* primalType, IRInst* witness); @@ -207,17 +257,21 @@ struct DifferentiableTypeConformanceContext IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + void addTypeToDictionary(IRType* type, IRInst* witness); + + IRInterfaceType* getConformanceTypeFromWitness(IRInst* witness); + IRInst* tryExtractConformanceFromInterfaceType( IRBuilder* builder, IRInterfaceType* interfaceType, IRWitnessTable* witnessTable); - List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath( - IRInterfaceType* idiffType, + List<IRInterfaceRequirementEntry*> findInterfaceLookupPath( + IRInterfaceType* supType, IRInterfaceType* type); // Lookup and return the 'Differential' type declared in the concrete type - // in order to conform to the IDifferentiable interface. + // in order to conform to the IDifferentiable/IDifferentiablePtrType interfaces // Note that inside a generic block, this will be a witness table lookup instruction // that gets resolved during the specialization pass. // @@ -227,8 +281,10 @@ struct DifferentiableTypeConformanceContext { case kIROp_InterfaceType: { - if (isDifferentiableType(origType)) + if (isDifferentiableValueType(origType)) return this->sharedContext->differentiableInterfaceType; + else if (isDifferentiablePtrType(origType)) + return this->sharedContext->differentiablePtrInterfaceType; else return nullptr; } @@ -254,13 +310,30 @@ struct DifferentiableTypeConformanceContext auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType); return builder->getDifferentialPairUserCodeType((IRType*)diffType, diffWitness); } + case kIROp_DifferentialPtrPairType: + { + auto diffPairType = as<IRDifferentialPairTypeBase>(origType); + auto diffType = getDiffTypeFromPairType(builder, diffPairType); + auto diffWitness = getDiffTypeWitnessFromPairType(builder, diffPairType); + return builder->getDifferentialPtrPairType((IRType*)diffType, diffWitness); + } default: - return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey); + if (isDifferentiableValueType(origType)) + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocTypeStructKey, builder->getTypeKind()); + else if (isDifferentiablePtrType(origType)) + return lookUpInterfaceMethod(builder, origType, sharedContext->differentialAssocRefTypeStructKey, builder->getTypeKind()); + else + return nullptr; } } bool isDifferentiableType(IRType* origType) { + return isDifferentiableValueType(origType) || isDifferentiablePtrType(origType); + } + + bool isDifferentiableValueType(IRType* origType) + { for (; origType;) { switch (origType->getOp()) @@ -279,7 +352,27 @@ struct DifferentiableTypeConformanceContext origType = (IRType*)origType->getOperand(0); continue; default: - return lookUpConformanceForType(origType) != nullptr; + return lookUpConformanceForType(origType, DiffConformanceKind::Value) != nullptr; + } + } + return false; + } + + bool isDifferentiablePtrType(IRType* origType) + { + for (; origType;) + { + switch (origType->getOp()) + { + case kIROp_VectorType: + case kIROp_ArrayType: + case kIROp_PtrType: + case kIROp_OutType: + case kIROp_InOutType: + origType = (IRType*)origType->getOperand(0); + continue; + default: + return lookUpConformanceForType(origType, DiffConformanceKind::Ptr) != nullptr; } } return false; @@ -287,13 +380,13 @@ struct DifferentiableTypeConformanceContext IRInst* getZeroMethodForType(IRBuilder* builder, IRType* origType) { - auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->zeroMethodStructKey, sharedContext->zeroMethodType); return result; } IRInst* getAddMethodForType(IRBuilder* builder, IRType* origType) { - auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); + auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey, sharedContext->addMethodType); return result; } @@ -307,8 +400,28 @@ struct DifferentiableTypeConformanceContext IRFunc* getOrCreateExistentialDAddMethod(); + IRInst* buildDifferentiablePairWitness( + IRBuilder* builder, + IRDifferentialPairTypeBase* pairType, + DiffConformanceKind target); + + IRInst* buildArrayWitness( + IRBuilder* builder, + IRArrayType* pairType, + DiffConformanceKind target); + + IRInst* buildTupleWitness( + IRBuilder* builder, + IRInst* tupleType, + DiffConformanceKind target); + + IRInst* buildExtractExistensialTypeWitness( + IRBuilder* builder, + IRExtractExistentialType* extractExistentialType, + DiffConformanceKind target); }; + struct DifferentialPairTypeBuilder { DifferentialPairTypeBuilder() = default; diff --git a/source/slang/slang-ir-check-differentiability.cpp b/source/slang/slang-ir-check-differentiability.cpp index 8b4886a2c..cae47fffd 100644 --- a/source/slang/slang-ir-check-differentiability.cpp +++ b/source/slang/slang-ir-check-differentiability.cpp @@ -625,7 +625,7 @@ public: } } - if (!sharedContext.isInterfaceAvailable) + if (!sharedContext.isInterfaceAvailable && !sharedContext.isPtrInterfaceAvailable) return; for (auto inst : module->getGlobalInsts()) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 301a9c789..0d689660e 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -61,7 +61,8 @@ INST(Nop, nop, 0, 0) INST(DifferentialPairType, DiffPair, 1, HOISTABLE) INST(DifferentialPairUserCodeType, DiffPairUserCode, 1, HOISTABLE) - INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPairUserCodeType) + INST(DifferentialPtrPairType, DiffRefPair, 1, HOISTABLE) + INST_RANGE(DifferentialPairTypeBase, DifferentialPairType, DifferentialPtrPairType) INST(BackwardDiffIntermediateContextType, BwdDiffIntermediateCtxType, 1, HOISTABLE) @@ -325,15 +326,18 @@ INST(DefaultConstruct, defaultConstruct, 0, 0) INST(MakeDifferentialPair, MakeDiffPair, 2, 0) INST(MakeDifferentialPairUserCode, MakeDiffPairUserCode, 2, 0) -INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPairUserCode) +INST(MakeDifferentialPtrPair, MakeDiffRefPair, 2, 0) +INST_RANGE(MakeDifferentialPairBase, MakeDifferentialPair, MakeDifferentialPtrPair) INST(DifferentialPairGetDifferential, GetDifferential, 1, 0) INST(DifferentialPairGetDifferentialUserCode, GetDifferentialUserCode, 1, 0) -INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPairGetDifferentialUserCode) +INST(DifferentialPtrPairGetDifferential, GetDifferentialPtr, 1, 0) +INST_RANGE(DifferentialPairGetDifferentialBase, DifferentialPairGetDifferential, DifferentialPtrPairGetDifferential) INST(DifferentialPairGetPrimal, GetPrimal, 1, 0) INST(DifferentialPairGetPrimalUserCode, GetPrimalUserCode, 1, 0) -INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPairGetPrimalUserCode) +INST(DifferentialPtrPairGetPrimal, GetPrimalRef, 1, 0) +INST_RANGE(DifferentialPairGetPrimalBase, DifferentialPairGetPrimal, DifferentialPtrPairGetPrimal) INST(Specialize, specialize, 2, HOISTABLE) INST(LookupWitness, lookupWitness, 2, HOISTABLE) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 37f242e55..f31a56673 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2959,6 +2959,10 @@ struct IRMakeDifferentialPairUserCode : IRMakeDifferentialPairBase { IR_LEAF_ISA(MakeDifferentialPairUserCode) }; +struct IRMakeDifferentialPtrPair : IRMakeDifferentialPairBase +{ + IR_LEAF_ISA(MakeDifferentialPtrPair) +}; struct IRDifferentialPairGetDifferentialBase : IRInst { @@ -2973,6 +2977,10 @@ struct IRDifferentialPairGetDifferentialUserCode : IRDifferentialPairGetDifferen { IR_LEAF_ISA(DifferentialPairGetDifferentialUserCode) }; +struct IRDifferentialPtrPairGetDifferential : IRDifferentialPairGetDifferentialBase +{ + IR_LEAF_ISA(DifferentialPtrPairGetDifferential) +}; struct IRDifferentialPairGetPrimalBase : IRInst { @@ -2987,6 +2995,10 @@ struct IRDifferentialPairGetPrimalUserCode : IRDifferentialPairGetPrimalBase { IR_LEAF_ISA(DifferentialPairGetPrimalUserCode) }; +struct IRDifferentialPtrPairGetPrimal : IRDifferentialPairGetPrimalBase +{ + IR_LEAF_ISA(DifferentialPtrPairGetPrimal) +}; struct IRDetachDerivative : IRInst { @@ -3657,6 +3669,10 @@ public: IRDifferentialPairType* getDifferentialPairType( IRType* valueType, IRInst* witnessTable); + + IRDifferentialPtrPairType* getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable); IRDifferentialPairUserCodeType* getDifferentialPairUserCodeType( IRType* valueType, @@ -3797,6 +3813,8 @@ public: IRInst* emitGetTorchCudaStream(); IRInst* emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential); + IRInst* emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential); IRInst* emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential); IRInst* addDifferentiableTypeDictionaryDecoration(IRInst* target); @@ -3979,9 +3997,19 @@ public: IRInst* emitGetOptionalValue(IRInst* optValue); IRInst* emitMakeOptionalValue(IRInst* optType, IRInst* value); IRInst* emitMakeOptionalNone(IRInst* optType, IRInst* defaultValue); + IRInst* emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialValuePairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetPrimal(IRInst* diffPair); + IRInst* emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair); + IRInst* emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair); IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair); IRInst* emitMakeVector( diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 6c7691d13..b89929f55 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3022,6 +3022,17 @@ namespace Slang operands); } + IRDifferentialPtrPairType* IRBuilder::getDifferentialPtrPairType( + IRType* valueType, + IRInst* witnessTable) + { + IRInst* operands[] = { valueType, witnessTable }; + return (IRDifferentialPtrPairType*)getType( + kIROp_DifferentialPtrPairType, + sizeof(operands) / sizeof(operands[0]), + operands); + } + IRDifferentialPairUserCodeType* IRBuilder::getDifferentialPairUserCodeType( IRType* valueType, IRInst* witnessTable) @@ -3503,7 +3514,7 @@ namespace Slang return inst; } - IRInst* IRBuilder::emitMakeDifferentialPair(IRType* type, IRInst* primal, IRInst* differential) + IRInst* IRBuilder::emitMakeDifferentialValuePair(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)); SLANG_RELEASE_ASSERT(as<IRDifferentialPairType>(type)->getValueType() != nullptr); @@ -3516,6 +3527,98 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitMakeDifferentialPtrPair(IRType* type, IRInst* primal, IRInst* differential) + { + SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type)); + SLANG_RELEASE_ASSERT(as<IRDifferentialPtrPairType>(type)->getValueType() != nullptr); + + IRInst* args[] = {primal, differential}; + auto inst = createInstWithTrailingArgs<IRMakeDifferentialPtrPair>( + this, kIROp_MakeDifferentialPtrPair, type, 2, args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitMakeDifferentialPair(IRType* pairType, IRInst* primalVal, IRInst* diffVal) + { + if (as<IRDifferentialPairType>(pairType)) + { + return emitMakeDifferentialValuePair(pairType, primalVal, diffVal); + } + else if (as<IRDifferentialPtrPairType>(pairType)) + { + // Quick optimization: + // If primalVal and diffVal are extracted from the same pointer-pair, + // we can just use the pointer-pair directly. + // + if (auto primalPtrVal = as<IRDifferentialPtrPairGetPrimal>(primalVal)) + { + if (auto diffPtrVal = as<IRDifferentialPtrPairGetDifferential>(diffVal)) + { + if (primalPtrVal->getBase() == diffPtrVal->getBase()) + return primalPtrVal->getBase(); + } + } + return emitMakeDifferentialPtrPair(pairType, primalVal, diffVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* pairVal) + { + if (as<IRDifferentialPairType>(pairVal->getDataType())) + { + return emitDifferentialValuePairGetDifferential(diffType, pairVal); + } + else if (as<IRDifferentialPtrPairType>(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetDifferential(diffType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* pairVal) + { + if (as<IRDifferentialPairType>(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(pairVal); + } + else if (as<IRDifferentialPtrPairType>(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + + IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* pairVal) + { + if (as<IRDifferentialPairType>(pairVal->getDataType())) + { + return emitDifferentialValuePairGetPrimal(primalType, pairVal); + } + else if (as<IRDifferentialPtrPairType>(pairVal->getDataType())) + { + return emitDifferentialPtrPairGetPrimal(primalType, pairVal); + } + else + { + SLANG_ASSERT(!"unreachable"); + return nullptr; + } + } + IRInst* IRBuilder::emitMakeDifferentialPairUserCode(IRType* type, IRInst* primal, IRInst* differential) { SLANG_RELEASE_ASSERT(as<IRDifferentialPairTypeBase>(type)); @@ -4222,7 +4325,7 @@ namespace Slang return emitIntrinsicInst(type, kIROp_MakeVector, argCount, args); } - IRInst* IRBuilder::emitDifferentialPairGetDifferential(IRType* diffType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetDifferential(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); return emitIntrinsicInst( @@ -4232,7 +4335,18 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRInst* diffPair) + + IRInst* IRBuilder::emitDifferentialPtrPairGetDifferential(IRType* diffType, IRInst* diffPair) + { + SLANG_ASSERT(as<IRDifferentialPtrPairType>(diffPair->getDataType())); + return emitIntrinsicInst( + diffType, + kIROp_DifferentialPtrPairGetDifferential, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRInst* diffPair) { auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType(); return emitIntrinsicInst( @@ -4242,7 +4356,7 @@ namespace Slang &diffPair); } - IRInst* IRBuilder::emitDifferentialPairGetPrimal(IRType* primalType, IRInst* diffPair) + IRInst* IRBuilder::emitDifferentialValuePairGetPrimal(IRType* primalType, IRInst* diffPair) { return emitIntrinsicInst( primalType, @@ -4251,6 +4365,25 @@ namespace Slang &diffPair); } + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRInst* diffPair) + { + auto valueType = cast<IRDifferentialPairTypeBase>(diffPair->getDataType())->getValueType(); + return emitIntrinsicInst( + valueType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + + IRInst* IRBuilder::emitDifferentialPtrPairGetPrimal(IRType* primalType, IRInst* diffPair) + { + return emitIntrinsicInst( + primalType, + kIROp_DifferentialPtrPairGetPrimal, + 1, + &diffPair); + } + IRInst* IRBuilder::emitDifferentialPairGetDifferentialUserCode(IRType* diffType, IRInst* diffPair) { SLANG_ASSERT(as<IRDifferentialPairTypeBase>(diffPair->getDataType())); diff --git a/source/slang/slang-ir.h b/source/slang/slang-ir.h index 375107d1d..14dde200f 100644 --- a/source/slang/slang-ir.h +++ b/source/slang/slang-ir.h @@ -1660,6 +1660,11 @@ struct IRDifferentialPairType : IRDifferentialPairTypeBase IR_LEAF_ISA(DifferentialPairType) }; +struct IRDifferentialPtrPairType : IRDifferentialPairTypeBase +{ + IR_LEAF_ISA(DifferentialPtrPairType) +}; + struct IRDifferentialPairUserCodeType : IRDifferentialPairTypeBase { IR_LEAF_ISA(DifferentialPairUserCodeType) |
