diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2023-08-17 14:45:13 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-17 14:45:13 -0400 |
| commit | 945409c4c6871c18aad24086c594cc66b5913733 (patch) | |
| tree | 41eed63f115971d82875e23acbec77d78be4cf3a /source | |
| parent | 216fc18661fd6e05053b4cc864396e6017e85b04 (diff) | |
Initial support for differentiating existential types (#3111)
* Merge
* WIP: Complete auto-diff logic for existential types
* Revert "Add compiler option for generating representative hash"
This reverts commit 13b09ef4621e73844c96d64d9c111a8ed0d45aae.
* More fixes for fwd-mode AD on existential types
* Add anyValueSize inference pass
* Fix checking of `Differential.Differential==Differential`
* In-progress: infer any-value-size for existential types
* Existentials now work in forward-mode
* Overhaul handling of existential AD types. Fwd-mode works, reverse-mode requires front-end changes
* Reverse-mode now works on existentials
* Cleanup
* Remove diff rules for create existential object for now
* Revert treat-as-differentiable changes
* Fixes
* More fixes
* Cleanup
* more cleanup
* signed/unsigned
* Revert "Cleanup"
This reverts commit e4f7d71f07bb207736f90708961eeecd09a1b652.
* Cleanup (again)
* Remove public/export/keep-alive on null differential after AD pass
* Minor fix
* Update dictionary accessors
* Keep export decoration
* More fixes + Support for `kIROp_PackAnyValue`
* Merge upstream
* Update expected-failure.txt
Diffstat (limited to 'source')
29 files changed, 1304 insertions, 173 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index ce0e72d34..423b6bfd0 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -25,6 +25,30 @@ attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [NoDiffThis] : NoDiffThisAttribute; +// A 'none-type' that acts as a run-time sentinel for zero differentials. +public struct NullDifferential : IDifferentiable +{ + // for now, we'll use at least one field to make sure the type is non-empty + uint dummy; + typedef NullDifferential Differential; + + [Differentiable] + [ForceInline] + static Differential dzero() { return { 0 }; } + + [Differentiable] + [ForceInline] + static Differential dadd(Differential, Differential) { return { 0 }; } + + [Differentiable] + [ForceInline] + static Differential dmul<T: __BuiltinRealType>(T, Differential) { return { 0 }; } +}; + +// Existential check for null differential type +__intrinsic_op($(kIROp_IsDifferentialNull)) +bool isDifferentialNull(IDifferentiable obj); + /// Represents a GPU view of a tensor. __generic<T> __magic_type(TensorViewType) diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 8266d77c7..553a5c26f 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -554,6 +554,9 @@ class DerivativeRequirementDecl : public FunctionDeclBase // The original requirement decl. Decl* originalRequirementDecl = nullptr; + + // Type to use for 'ThisType' + Type* diffThisType; }; // A reference to a synthesized decl representing a differentiable function requirement, this decl will diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index cb3db9e39..f25821dac 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1656,8 +1656,7 @@ namespace Slang RequirementWitness witnessValue; auto requirementDecl = m_astBuilder->getSharedASTBuilder()->findBuiltinRequirementDecl(BuiltinRequirementKind::DifferentialType); if (!inheritanceDecl->witnessTable->getRequirementDictionary().tryGetValue(requirementDecl, witnessValue)) - return; - + return; // A type used as differential type must have itself as its own differential type. if (witnessValue.getFlavor() != RequirementWitness::Flavor::val) return; @@ -5781,6 +5780,16 @@ namespace Slang interfaceDecl->members.add(reqDecl); reqDecl->parentDecl = interfaceDecl; + if (!decl->hasModifier<NoDiffThisAttribute>()) + { + // Build decl-ref-type from interface. + auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + + // If the interface is differentiable, make the this type a pair. + if (tryGetDifferentialType(getASTBuilder(), interfaceType)) + reqDecl->diffThisType = getDifferentialPairType(interfaceType); + } + auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); reqRef->referencedDecl = reqDecl; reqRef->parentDecl = decl; @@ -5800,6 +5809,15 @@ namespace Slang setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); interfaceDecl->members.add(reqDecl); reqDecl->parentDecl = interfaceDecl; + if (!decl->hasModifier<NoDiffThisAttribute>()) + { + // Build decl-ref-type from interface. + auto interfaceType = DeclRefType::create(getASTBuilder(), makeDeclRef(interfaceDecl)); + + // If the interface is differentiable, make the this type a pair. + if (tryGetDifferentialType(getASTBuilder(), interfaceType)) + reqDecl->diffThisType = getDifferentialPairType(interfaceType); + } auto reqRef = m_astBuilder->create<DerivativeRequirementReferenceDecl>(); reqRef->referencedDecl = reqDecl; diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 05cb6262b..3d2f81edb 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -500,6 +500,12 @@ namespace Slang // Don't synthesize for ThisType. if (as<ThisTypeDecl>(subType->getDeclRef().getDecl())) return nullptr; + + // If the inner most subtype is itself an associated type, then we're dealing + // with an abstract type. There's not need to synthesize anythin at this point. + // + if (as<AssocTypeDecl>(subType->getDeclRef().getDecl())) + return nullptr; // If we reach here, we are expecting a synthesized decl defined in `subType`. // Instead of returning a DeclRefExpr to the requirement decl, we synthesize a placeholder decl diff --git a/source/slang/slang-compiler.h b/source/slang/slang-compiler.h index f29dc8dae..d87b755c7 100755 --- a/source/slang/slang-compiler.h +++ b/source/slang/slang-compiler.h @@ -2671,7 +2671,7 @@ namespace Slang virtual SLANG_NO_THROW void SLANG_MCALL setDebugInfoFormat(SlangDebugInfoFormat format) SLANG_OVERRIDE; virtual SLANG_NO_THROW void SLANG_MCALL setReportDownstreamTime(bool value) SLANG_OVERRIDE; virtual SLANG_NO_THROW void SLANG_MCALL setReportPerfBenchmark(bool value) SLANG_OVERRIDE; - + void setHLSLToVulkanLayoutOptions(int targetIndex, HLSLToVulkanLayoutOptions* vulkanLayoutOptions); EndToEndCompileRequest( diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 82f2a8fd3..6521b05ba 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -5,6 +5,7 @@ #include "../compiler-core/slang-name.h" +#include "slang-ir-any-value-inference.h" #include "slang-ir-bind-existentials.h" #include "slang-ir-byte-address-legalize.h" #include "slang-ir-collect-global-uniforms.h" diff --git a/source/slang/slang-ir-any-value-inference.cpp b/source/slang/slang-ir-any-value-inference.cpp new file mode 100644 index 000000000..eb4aa670f --- /dev/null +++ b/source/slang/slang-ir-any-value-inference.cpp @@ -0,0 +1,231 @@ +#include "slang-ir-any-value-inference.h" + +#include "slang-ir-generics-lowering-context.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-ir-util.h" +#include "slang-ir-layout.h" +#include "../core/slang-func-ptr.h" + +namespace Slang +{ + + void _findDependenciesOfTypeInSet(IRType* type, HashSet<IRInterfaceType*>& targetSet, List<IRInterfaceType*>& result) + { + switch (type->getOp()) + { + case kIROp_InterfaceType: + { + auto interfaceType = cast<IRInterfaceType>(type); + if (targetSet.contains(interfaceType)) + { + result.add(interfaceType); + return; + } + } + break; + case kIROp_StructType: + { + auto structType = cast<IRStructType>(type); + for (auto field : structType->getFields()) + { + _findDependenciesOfTypeInSet(field->getFieldType(), targetSet, result); + } + } + break; + default: + { + for (UInt i = 0; i < type->getOperandCount(); i++) + { + if (auto operandType = as<IRType>(type->getOperand(i))) + _findDependenciesOfTypeInSet(operandType, targetSet, result); + } + } + break; + } + } + + List<IRInterfaceType*> findDependenciesOfTypeInSet(IRType* type, HashSet<IRInterfaceType*> targetSet) + { + List<IRInterfaceType*> result; + _findDependenciesOfTypeInSet(type, targetSet, result); + + return result; + } + + void _sortTopologically( + IRInterfaceType* interfaceType, + HashSet<IRInterfaceType*>& visited, + List<IRInterfaceType*>& sortedInterfaceTypes, + const Func<HashSet<IRInterfaceType*>, IRInterfaceType*>& getDependencies) + { + if (visited.contains(interfaceType)) + return; + + visited.add(interfaceType); + + for (auto dependency : getDependencies(interfaceType)) + { + _sortTopologically(dependency, visited, sortedInterfaceTypes, getDependencies); + } + + sortedInterfaceTypes.add(interfaceType); + } + + List<IRInterfaceType*> sortTopologically( + HashSet<IRInterfaceType*> interfaceTypes, + const Func<HashSet<IRInterfaceType*>, IRInterfaceType*>& getDependencies) + { + List<IRInterfaceType*> sortedInterfaceTypes; + HashSet<IRInterfaceType*> visited; + for (auto interfaceType : interfaceTypes) + { + _sortTopologically(interfaceType, visited, sortedInterfaceTypes, getDependencies); + } + return sortedInterfaceTypes; + } + + void inferAnyValueSizeWhereNecessary( + IRModule* module) + { + // Go through the global insts and collect all interface types. + // For each interface type, infer its any-value-size, by looking up + // all witness tables whose conformance type matches the interface type. + // then using _calcNaturalSizeAndAlignment to find the max size. + // + // Note: we only infer any-value-size for interface types that are used + // as a generic type parameter, because we don't want to infer any-value-size + // for interface types that are used as a witness table type. + // + + HashSet<IRInst*> implementedInterfaces; + // Add all interface type that are implemented by at least one type to a set. + for (auto inst : module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_WitnessTable) + { + auto interfaceType = cast<IRWitnessTableType>(inst->getDataType())->getConformanceType(); + implementedInterfaces.add(interfaceType); + } + } + + // Collect all interface types that require inference. + HashSet<IRInterfaceType*> interfaceTypes; + for (auto inst : module->getGlobalInsts()) + { + if (inst->getOp() == kIROp_InterfaceType) + { + auto interfaceType = cast<IRInterfaceType>(inst); + + // Do not infer anything for COM interfaces. + if (isComInterfaceType((IRType*)interfaceType)) + continue; + + // Also skip builtin types. + if (interfaceType->findDecoration<IRBuiltinDecoration>()) + continue; + + // If the interface already has an explicit any-value-size, don't infer anything. + if (interfaceType->findDecoration<IRAnyValueSizeDecoration>()) + continue; + + // Skip interfaces that are not implemented by any type. + if (!implementedInterfaces.contains(interfaceType)) + continue; + + interfaceTypes.add(interfaceType); + } + } + + Dictionary<IRInterfaceType*, List<IRInst*>> mapInterfaceToImplementations; + + // Collect all concrete types that conform to this interface type. + for (auto interfaceType : interfaceTypes) + { + IRWitnessTableType* witnessTableType = nullptr; + // Find witness table type corresponding to this interface. + for (auto use = interfaceType->firstUse; use; use = use->nextUse) + { + if (auto _witnessTableType = as<IRWitnessTableType>(use->getUser())) + { + if (_witnessTableType->getConformanceType() == interfaceType && _witnessTableType->hasUses()) + { + witnessTableType = _witnessTableType; + break; + } + } + } + + // If we hit this case, we have an interface without any conforming implementations. + // This case should be handled before this point. + // + SLANG_ASSERT(witnessTableType); + + List<IRInst*> implList; + + // Walk through all the uses of this witness table type to find the witness tables. + for (auto use = witnessTableType->firstUse; use; use = use->nextUse) + { + auto witnessTable = as<IRWitnessTable>(use->getUser()); + if (!witnessTable || witnessTable->getDataType() != witnessTableType) + continue; + + auto concreteImpl = witnessTable->getConcreteType(); + + // Only consider implementations at the top-level (ignore those nested + // in generics) + // + if (concreteImpl->getParent() == module->getModuleInst()) + implList.add(concreteImpl); + } + + mapInterfaceToImplementations.add(interfaceType, implList); + } + + Dictionary<IRInterfaceType*, HashSet<IRInterfaceType*>> interfaceDependencyMap; + + // Collect dependencies for each interface. + for (auto interfaceType : interfaceTypes) + { + HashSet<IRInterfaceType*> dependencySet; + for (auto impl : mapInterfaceToImplementations[interfaceType]) + { + auto dependencies = findDependenciesOfTypeInSet((IRType*)impl, interfaceTypes); + for (auto dependency : dependencies) + dependencySet.add(dependency); + } + interfaceDependencyMap.add(interfaceType, dependencySet); + } + + // Sort the interface types in topological order. + // This is necessary because we need to infer the any-value-size of an interface type + // before we infer the any-value-size of an interface type that depends on it. + // + List<IRInterfaceType*> sortedInterfaceTypes = sortTopologically(interfaceTypes, [&](IRInterfaceType* interfaceType) + { + return interfaceDependencyMap[interfaceType]; + }); + + for (auto interfaceType : sortedInterfaceTypes) + { + IRIntegerValue maxAnyValueSize = -1; + for (auto implType : mapInterfaceToImplementations[interfaceType]) + { + IRSizeAndAlignment sizeAndAlignment; + getNaturalSizeAndAlignment((IRType*)implType, &sizeAndAlignment); + + maxAnyValueSize = Math::Max(maxAnyValueSize, sizeAndAlignment.size); + } + + // Should not encounter interface types without any conforming implementations. + SLANG_ASSERT(maxAnyValueSize >= 0); + + // If we found a max size, add an any-value-size decoration to the interface type. + if (maxAnyValueSize >= 0) + { + IRBuilder builder(module); + builder.addAnyValueSizeDecoration(interfaceType, maxAnyValueSize); + } + } + } +};
\ No newline at end of file diff --git a/source/slang/slang-ir-any-value-inference.h b/source/slang/slang-ir-any-value-inference.h new file mode 100644 index 000000000..eb202d626 --- /dev/null +++ b/source/slang/slang-ir-any-value-inference.h @@ -0,0 +1,13 @@ +// slang-ir-any-value-inference.h +#pragma once + +#include "../core/slang-common.h" +#include "slang-ir.h" +#include "slang-ir-insts.h" +#include "slang-compiler.h" + +namespace Slang +{ + void inferAnyValueSizeWhereNecessary( + IRModule* module); +} diff --git a/source/slang/slang-ir-any-value-marshalling.cpp b/source/slang/slang-ir-any-value-marshalling.cpp index 79aea9011..ed7818dbf 100644 --- a/source/slang/slang-ir-any-value-marshalling.cpp +++ b/source/slang/slang-ir-any-value-marshalling.cpp @@ -1,5 +1,6 @@ #include "slang-ir-any-value-marshalling.h" +#include "../core/slang-math.h" #include "slang-ir-generics-lowering-context.h" #include "slang-ir.h" #include "slang-ir-insts.h" @@ -782,6 +783,46 @@ namespace Slang auto size = SharedGenericsLoweringContext::getInterfaceAnyValueSize(interfaceType, interfaceType->sourceLoc); return alignUp(offset, 4) + alignUp((SlangInt)size, 4); } + case kIROp_LookupWitness: + { + auto witnessTableVal = type->getOperand(0); + auto key = type->getOperand(1); + IRType* assocType = nullptr; + if (auto witnessTableType = as<IRWitnessTableTypeBase>(witnessTableVal->getDataType())) + { + auto interfaceType = as<IRInterfaceType>(witnessTableType->getConformanceType()); + + // Walk through interface operands to find a match, the result should be an + // associated type entry. + // + for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++) + { + auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii)); + if (entry->getRequirementKey() == key && + as<IRAssociatedType>(entry->getRequirementVal())) + { + assocType = (IRType*)entry->getRequirementVal(); + break; + } + } + } + + if (!assocType) + return -1; + + IRIntegerValue anyValueSize = kInvalidAnyValueSize; + for (UInt i = 0; i < assocType->getOperandCount(); i++) + { + anyValueSize = Math::Min( + anyValueSize, + SharedGenericsLoweringContext::getInterfaceAnyValueSize(assocType->getOperand(i), type->sourceLoc)); + } + + if (anyValueSize == kInvalidAnyValueSize) + return -1; + + return alignUp(offset, 4) + alignUp((SlangInt)anyValueSize, 4); + } default: if (as<IRTextureTypeBase>(type) || as<IRSamplerStateTypeBase>(type)) { diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index c17d7d5c4..2662498ed 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -142,6 +142,23 @@ InstPair ForwardDiffTranscriber::transcribeUndefined(IRBuilder* builder, IRInst* return InstPair(primalVal, nullptr); } +InstPair ForwardDiffTranscriber::transcribeReinterpret(IRBuilder* builder, IRInst* origInst) +{ + auto primalVal = maybeCloneForPrimalInst(builder, origInst); + + IRInst* diffVal = nullptr; + + if (IRType* const diffType = differentiateType(builder, origInst->getFullType())) + { + if (auto diffOperand = findOrTranscribeDiffInst(builder, origInst->getOperand(0))) + { + diffVal = builder->emitReinterpret(diffType, diffOperand); + } + } + + return InstPair(primalVal, diffVal); +} + InstPair ForwardDiffTranscriber::transcribeVar(IRBuilder* builder, IRVar* origVar) { if (IRType* diffType = differentiateType(builder, origVar->getDataType()->getValueType())) @@ -230,10 +247,12 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns diffLeft, builder->getFloatValue( constant->getDataType(), 1.0 / constant->getValue())); + builder->markInstAsDifferential(diff, resultType); } else { diff = builder->emitDiv(diffType, diffLeft, primalRight); + builder->markInstAsDifferential(diff, resultType); } return InstPair(primalArith, diff); } @@ -247,6 +266,7 @@ InstPair ForwardDiffTranscriber::transcribeBinaryArith(IRBuilder* builder, IRIns auto diffSub = builder->emitSub(diffType, diffLeftTimesRight, diffRightTimesLeft); builder->markInstAsDifferential(diffSub, resultType); + auto diffMul = builder->emitMul(primalRight->getFullType(), primalRight, primalRight); builder->markInstAsPrimal(diffMul); @@ -661,7 +681,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig auto pairPtrType = as<IRPtrTypeBase>(pairType); auto pairValType = as<IRDifferentialPairType>( pairPtrType ? pairPtrType->getValueType() : pairType); - auto diffType = differentiableTypeConformanceContext.getDifferentialTypeFromDiffPairType(&argBuilder, pairValType); + auto diffType = differentiableTypeConformanceContext.getDiffTypeFromPairType(&argBuilder, pairValType); if (auto ptrParamType = as<IRPtrTypeBase>(diffParamType)) { // Create temp var to pass in/out arguments. @@ -698,6 +718,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig { auto newDiffVal = afterBuilder.emitDifferentialPairGetDifferential((IRType*)diffType, newVal); afterBuilder.markInstAsDifferential(newDiffVal, pairValType->getValueType()); + auto storeInst = afterBuilder.emitStore(diffArg, newDiffVal); afterBuilder.markInstAsDifferential(storeInst, pairValType->getValueType()); } @@ -1389,16 +1410,19 @@ InstPair ForwardDiffTranscriber::transcribeMakeExistential(IRBuilder* builder, I SLANG_RELEASE_ASSERT(primalInterfaceType); // If the interface type of the existential is differentiable, we emit a make existential - // of IDifferentiable interface type and the witness table of the original type's conformance + // of IDifferentiable.Differential type and the witness table of the original type's conformance // to IDifferentiable. // - if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + if (auto differentialWitnessTable = differentiableTypeConformanceContext.tryExtractConformanceFromInterfaceType( builder, primalInterfaceType, (IRWitnessTable*)primalWitnessTable)) { if (auto diffBase = findOrTranscribeDiffInst(builder, origBase)) { + auto differentialAssociatedType = differentiateType(builder, primalInterfaceType); + SLANG_ASSERT(differentialAssociatedType); + diffResult = builder->emitMakeExistential( - autoDiffSharedContext->differentiableInterfaceType, + differentialAssociatedType, diffBase, differentialWitnessTable); } @@ -1735,6 +1759,7 @@ InstPair ForwardDiffTranscriber::transcribeFunc(IRBuilder* inBuilder, IRFunc* pr { auto diffVal = builder.emitLoad(writeBack.value.differential); builder.markInstAsDifferential(diffVal, primalVal->getFullType()); + valToStore = builder.emitMakeDifferentialPair(cast<IRPtrTypeBase>(param->getFullType())->getValueType(), primalVal, diffVal); builder.markInstAsMixedDifferential(valToStore, valToStore->getFullType()); @@ -1867,6 +1892,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_ExtractExistentialValue: return transcribeSingleOperandInst(builder, origInst); + + case kIROp_PackAnyValue: + return transcribeSingleOperandInst(builder, origInst); case kIROp_MakeExistential: return transcribeMakeExistential(builder, as<IRMakeExistential>(origInst)); @@ -1874,10 +1902,16 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_ExtractExistentialType: { IRInst* witnessTable; + auto diffType = differentiateExtractExistentialType( + builder, as<IRExtractExistentialType>(origInst), witnessTable); + + // Mark types as primal since they are not transposable. + if (diffType) + builder->markInstAsPrimal(diffType); + return InstPair( maybeCloneForPrimalInst(builder, origInst), - differentiateExtractExistentialType( - builder, as<IRExtractExistentialType>(origInst), witnessTable)); + diffType); } case kIROp_ExtractExistentialWitnessTable: return transcribeExtractExistentialWitnessTable(builder, origInst); @@ -1890,6 +1924,9 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_undefined: return transcribeUndefined(builder, origInst); + + case kIROp_Reinterpret: + return transcribeReinterpret(builder, origInst); // Differentiable insts that should have been lowered in a previous pass. case kIROp_SwizzledStore: @@ -1901,7 +1938,6 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* SLANG_RELEASE_ASSERT(lookupDiffInst(swizzledStore->getDest(), nullptr) == nullptr); return transcribeNonDiffInst(builder, swizzledStore); } - // Known non-differentiable insts. case kIROp_Not: case kIROp_BitAnd: @@ -1918,12 +1954,10 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* case kIROp_RWStructuredBufferLoadStatus: case kIROp_RWStructuredBufferStore: case kIROp_RWStructuredBufferGetElementPtr: - case kIROp_Reinterpret: case kIROp_IsType: case kIROp_ImageSubscript: case kIROp_ImageLoad: case kIROp_ImageStore: - case kIROp_PackAnyValue: case kIROp_UnpackAnyValue: case kIROp_GetNativePtr: case kIROp_CastIntToFloat: @@ -1936,6 +1970,11 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst* // A call to createDynamicObject<T>(arbitraryData) cannot provide a diff value, // so we treat this inst as non differentiable. // We can extend the frontend and IR with a separate op-code that can provide an explicit diff value. + // + // However, we can't skip this instruction since it also produces a _type_ which may be used by + // other differentiable instructions. Therefore, we'll create another existential object but with + // a dzero() for it's value. + // case kIROp_CreateExistentialObject: return transcribeNonDiffInst(builder, origInst); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index 4edb9301a..8d8d65c10 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -88,6 +88,8 @@ struct ForwardDiffTranscriber : AutoDiffTranscriberBase InstPair transcribeDefaultConstruct(IRBuilder* builder, IRInst* origInst); + InstPair transcribeReinterpret(IRBuilder* builder, IRInst* origInst); + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) override; void generateTrivialFwdDiffFunc(IRFunc* primalFunc, IRFunc* diffFunc); diff --git a/source/slang/slang-ir-autodiff-primal-hoist.cpp b/source/slang/slang-ir-autodiff-primal-hoist.cpp index 6ccf7caf4..ebf7a9484 100644 --- a/source/slang/slang-ir-autodiff-primal-hoist.cpp +++ b/source/slang/slang-ir-autodiff-primal-hoist.cpp @@ -3,6 +3,7 @@ #include "slang-ir-autodiff-region.h" #include "slang-ir-simplify-cfg.h" #include "slang-ir-util.h" +#include "../core/slang-func-ptr.h" #include "slang-ir.h" namespace Slang @@ -1087,8 +1088,12 @@ IRVar* emitIndexedLocalVar( IRType* baseType, const List<IndexTrackingInfo>& defBlockIndices) { + // Cannot store pointers. Case should have been handled by now. SLANG_RELEASE_ASSERT(!as<IRPtrTypeBase>(baseType)); + // Cannot store types. Case should have been handled by now. + SLANG_RELEASE_ASSERT(!as<IRTypeType>(baseType)); + IRBuilder varBuilder(varBlock->getModule()); varBuilder.setInsertBefore(varBlock->getFirstOrdinaryInst()); @@ -1242,23 +1247,112 @@ static int getInstRegionNestLevel( return (int)result; } + +struct UseChain +{ + List<IRUse*> chain; + static List<UseChain> from( + IRUse* baseUse, + Func<bool, IRUse*> isRelevantUse, + Func<bool, IRInst*> passthroughInst) + { + IRInst* inst = baseUse->getUser(); + + // Base case 1: we hit a relevant use, return a single-element chain. + if (isRelevantUse(baseUse)) + { + UseChain baseUseChain; + baseUseChain.chain.add(baseUse); + + return List<UseChain>(UseChain(baseUseChain)); + } + + // Base case 2: we hit an irrelevant use that is not also a passthrough. + // so stop here. + if (!passthroughInst(inst)) + { + return List<UseChain>(); + } + + // Recurse. + List<UseChain> result; + for (auto use = inst->firstUse; use; use = use->nextUse) + { + List<UseChain> innerChain = from(use, isRelevantUse, passthroughInst); + + for (auto& useChain : innerChain) + { + useChain.chain.add(baseUse); + result.add(useChain); + } + } + + return result; + } + + void replace(IRBuilder* builder, IRInst* inst) + { + SLANG_ASSERT(chain.getCount() > 0); + + // Simple case: if there is only one use, then we can just replace it. + if (chain.getCount() == 1) + { + builder->replaceOperand(chain.getLast(), inst); + chain.clear(); + return; + } + + IRCloneEnv env; + + // Pop the last use, which is the base use that needs to be replaced. + auto baseUse = chain.getLast(); + chain.removeLast(); + + // Ensure that replacement inst is set as mapping for the baseUse. + env.mapOldValToNew[baseUse->get()] = inst; + + auto lastInstInChain = inst; + + IRBuilder chainBuilder(builder->getModule()); + setInsertAfterOrdinaryInst(&chainBuilder, inst); + + // Clone the rest of the chain. + for (auto& use : chain) + { + lastInstInChain = cloneInst(&env, &chainBuilder, use->getUser()); + } + + // Replace the base use. + builder->replaceOperand(baseUse, lastInstInChain); + + chain.clear(); + } + + IRInst* getUser() const + { + SLANG_ASSERT(chain.getCount() > 0); + return chain.getLast()->getUser(); + } +}; + + // Trim defBlockIndices based on the indices of out of scope uses. // static List<IndexTrackingInfo> maybeTrimIndices( const List<IndexTrackingInfo>& defBlockIndices, const Dictionary<IRBlock*, List<IndexTrackingInfo>>& indexedBlockInfo, - const List<IRUse*>& outOfScopeUses) + const List<UseChain>& outOfScopeUses) { // Go through uses, lookup the defBlockIndices, and remove any indices if they // are not present in any of the uses. (This is sort of slow...) // List<IndexTrackingInfo> result; - for (auto& index : defBlockIndices) + for (const auto& index : defBlockIndices) { bool found = false; - for (auto& use : outOfScopeUses) + for (const auto& use : outOfScopeUses) { - auto useInst = use->getUser(); + auto useInst = use.getUser(); auto useBlock = useInst->getParent(); auto useBlockIndices = indexedBlockInfo.getValue(as<IRBlock>(useBlock)); if (useBlockIndices.contains(index)) @@ -1273,6 +1367,18 @@ static List<IndexTrackingInfo> maybeTrimIndices( return result; } +bool canInstBeStored(IRInst* inst) +{ + // Cannot store insts whose value is a type or a witness table. + // These insts get lowered to target-specific logic, and cannot be + // stored into variables or context structs as normal values. + // + if (as<IRTypeType>(inst->getDataType()) || as<IRWitnessTableType>(inst->getDataType())) + return false; + + return true; +} + /// Legalizes all accesses to primal insts from recompute and diff blocks. /// @@ -1352,8 +1458,19 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( { SLANG_ASSERT(!isDifferentialBlock(defaultVarBlock)); - for (auto instToStore : instSet) + List<IRInst*> workList; + for (auto inst : instSet) + workList.add(inst); + + HashSet<IRInst*> seenInstSet; + while (workList.getCount() != 0) { + auto instToStore = workList.getLast(); + workList.removeLast(); + + if (seenInstSet.contains(instToStore)) + continue; + IRBlock* defBlock = nullptr; if (const auto ptrInst = as<IRPtrTypeBase>(instToStore->getDataType())) { @@ -1367,45 +1484,61 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( SLANG_RELEASE_ASSERT(defBlock); - List<IRUse*> outOfScopeUses; + List<UseChain> outOfScopeUses; for (auto use = instToStore->firstUse; use;) { auto nextUse = use->nextUse; - // Only consider uses in differential blocks. - // This method is not responsible for other blocks. - // - IRBlock* userBlock = getBlock(use->getUser()); - if (isDifferentialOrRecomputeBlock(userBlock)) + // Lambda to check if a use is relevant. + auto isRelevantUse = [&](IRUse* use) { - if (!domTree->dominates(defBlock, userBlock)) - { - outOfScopeUses.add(use); - } - else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock])) - { - outOfScopeUses.add(use); - } - else if (getInstRegionNestLevel(indexedBlockInfo, defBlock, instToStore) > 0 && - !isDifferentialOrRecomputeBlock(defBlock)) + // Only consider uses in differential blocks. + // This method is not responsible for other blocks. + // + IRBlock* userBlock = getBlock(use->getUser()); + if (isDifferentialOrRecomputeBlock(userBlock)) { - outOfScopeUses.add(use); - } - else if (as<IRPtrTypeBase>(instToStore->getDataType()) && - !isDifferentialOrRecomputeBlock(defBlock)) - { - outOfScopeUses.add(use); + if (!domTree->dominates(defBlock, userBlock)) + { + return true; + } + else if (!areIndicesSubsetOf(indexedBlockInfo[defBlock], indexedBlockInfo[userBlock])) + { + return true; + } + else if (getInstRegionNestLevel(indexedBlockInfo, defBlock, instToStore) > 0 && + !isDifferentialOrRecomputeBlock(defBlock)) + { + return true; + } + else if (as<IRPtrTypeBase>(instToStore->getDataType()) && + !isDifferentialOrRecomputeBlock(defBlock)) + { + return true; + } } - } + return false; + }; + + // Lambda to check if an inst is transparent. We lookup uses 'through' transparent + // insts recursively. + // + auto isPassthroughInst = [&](IRInst* inst) + { + return !canInstBeStored(inst); + }; + + List<UseChain> useChains = UseChain::from(use, isRelevantUse, isPassthroughInst); + outOfScopeUses.addRange(useChains); use = nextUse; } if (outOfScopeUses.getCount() == 0) { - if (!isRecomputeInst) processedStoreSet.add(instToStore); + seenInstSet.add(instToStore); continue; } @@ -1457,9 +1590,9 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( for (auto use : outOfScopeUses) { - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser())); - List<IndexTrackingInfo>& useBlockIndices = indexedBlockInfo[getBlock(use->getUser())]; + List<IndexTrackingInfo>& useBlockIndices = indexedBlockInfo[getBlock(use.getUser())]; IRInst* loadAddr = emitIndexedLoadAddressForVar( &builder, @@ -1467,12 +1600,37 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( defBlock, defBlockIndices, useBlockIndices); - builder.replaceOperand(use, loadAddr); + use.replace(&builder, loadAddr); } if (!isRecomputeInst) processedStoreSet.add(localVar); } + else if (!canInstBeStored(instToStore)) + { + // We won't actually process these insts here. Instead we'll + // simply make sure that their operands are either already present + // in the worklist or add them to the worklist for legalization. + // + + List<IRInst*> pendingOperands; + for (UIndex ii = 0; ii < instToStore->getOperandCount(); ii++) + { + auto operand = instToStore->getOperand(ii); + if (!instSet.contains(operand) && !seenInstSet.contains(operand)) + { + if(getBlock(operand) && + (getBlock(operand)->getParent() == getBlock(instToStore)->getParent())) + pendingOperands.add(operand); + } + } + + if (pendingOperands.getCount() > 0) + { + for (Index ii = pendingOperands.getCount() - 1; ii >= 0; --ii) + workList.add(pendingOperands[ii]); + } + } else { // Handle the special case of loop counters. @@ -1495,16 +1653,18 @@ RefPtr<HoistedPrimalsInfo> ensurePrimalAvailability( for (auto use : outOfScopeUses) { - List<IndexTrackingInfo> useBlockIndices = indexedBlockInfo[getBlock(use->getUser())]; - setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use->getUser())); - builder.replaceOperand( - use, + List<IndexTrackingInfo> useBlockIndices = indexedBlockInfo[getBlock(use.getUser())]; + setInsertBeforeOrdinaryInst(&builder, getInstInBlock(use.getUser())); + use.replace( + &builder, loadIndexedValue(&builder, localVar, defBlock, defBlockIndices, useBlockIndices)); } if (!isRecomputeInst) processedStoreSet.add(localVar); } + + seenInstSet.add(instToStore); } }; diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 532fb88ac..8d7582373 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -913,7 +913,7 @@ namespace Slang { primalType = diffPairType->getValueType(); diffType = (IRType*)differentiableTypeConformanceContext - .getDifferentialTypeFromDiffPairType(builder, diffPairType); + .getDiffTypeFromPairType(builder, diffPairType); } // Now we handle each combination of parameter direction x differentiability. diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 0a9ff51a4..24e26f943 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -369,17 +369,30 @@ IRType* AutoDiffTranscriberBase::getOrCreateDiffPairType(IRBuilder* builder, IRI autoDiffSharedContext->differentialAssocTypeWitnessStructKey); } } + + // Obtain the witness that primalType conforms to IDifferentiable. if (!witness) witness = tryGetDifferentiableWitness(builder, originalType); SLANG_RELEASE_ASSERT(witness); - return builder->getDifferentialPairType( + auto pairType = builder->getDifferentialPairType( (IRType*)primalType, witness); + + return pairType; } IRType* AutoDiffTranscriberBase::differentiateType(IRBuilder* builder, IRType* origType) { + // Special-case for differentiable existential types. + if (as<IRInterfaceType>(origType) || as<IRAssociatedType>(origType)) + { + if (differentiableTypeConformanceContext.lookUpConformanceForType(origType)) + return autoDiffSharedContext->differentiableInterfaceType; + else + return nullptr; + } + auto primalType = lookupPrimalInst(builder, origType, origType); if (primalType->getOp() == kIROp_Param && primalType->getParent() && primalType->getParent()->getParent() && @@ -482,72 +495,17 @@ IRType* AutoDiffTranscriberBase::_differentiateTypeImpl(IRBuilder* builder, IRTy } } -// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. -static bool _findDifferentiableInterfaceLookupPathImpl( - HashSet<IRInst*>& processedTypes, - IRInterfaceType* idiffType, - IRInterfaceType* type, - List<IRInterfaceRequirementEntry*>& currentPath) -{ - if (processedTypes.contains(type)) - return false; - processedTypes.add(type); - - List<IRInterfaceRequirementEntry*> lookupKeyPath; - for (UInt i = 0; i < type->getOperandCount(); i++) - { - auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i)); - if (!entry) continue; - if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal())) - { - currentPath.add(entry); - if (wt->getConformanceType() == idiffType) - { - return true; - } - else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType())) - { - if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) - return true; - } - currentPath.removeLast(); - } - } - return false; -} - -List<IRInterfaceRequirementEntry*> AutoDiffTranscriberBase::findDifferentiableInterfaceLookupPath( - IRInterfaceType* idiffType, - IRInterfaceType* type) -{ - List<IRInterfaceRequirementEntry*> currentPath; - HashSet<IRInst*> processedTypes; - _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); - return currentPath; -} - -IRInst* AutoDiffTranscriberBase::tryExtractConformanceFromInterfaceType( - IRBuilder* builder, - IRInterfaceType* interfaceType, - IRWitnessTable* witnessTable) +bool AutoDiffTranscriberBase::isExistentialType(IRType *type) { - SLANG_RELEASE_ASSERT(interfaceType); - - List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( - autoDiffSharedContext->differentiableInterfaceType, interfaceType); - - IRInst* differentialTypeWitness = witnessTable; - if (lookupKeyPath.getCount()) + switch (type->getOp()) { - // `interfaceType` does conform to `IDifferentiable`. - for (auto node : lookupKeyPath) - { - differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); - } - return differentialTypeWitness; + case kIROp_ExtractExistentialType: + case kIROp_InterfaceType: + case kIROp_AssociatedType: + return true; + default: + return false; } - - return nullptr; } InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst) @@ -569,7 +527,7 @@ InstPair AutoDiffTranscriberBase::transcribeExtractExistentialWitnessTable(IRBui if (!interfaceType) return InstPair(primalResult, nullptr); - if (auto differentialWitnessTable = tryExtractConformanceFromInterfaceType( + if (auto differentialWitnessTable = differentiableTypeConformanceContext.tryExtractConformanceFromInterfaceType( builder, interfaceType, (IRWitnessTable*)primalResult)) { // `interfaceType` does conform to `IDifferentiable`. @@ -630,7 +588,7 @@ IRType* AutoDiffTranscriberBase::differentiateExtractExistentialType(IRBuilder* auto interfaceType = as<IRInterfaceType>(unwrapAttributedType(origType->getOperand(0)->getDataType())); if (!interfaceType) return nullptr; - List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + List<IRInterfaceRequirementEntry*> lookupKeyPath = differentiableTypeConformanceContext.findDifferentiableInterfaceLookupPath( autoDiffSharedContext->differentiableInterfaceType, interfaceType); if (lookupKeyPath.getCount()) @@ -737,6 +695,13 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui (IRType*)primalDiffType, primal, autoDiffSharedContext->differentialAssocTypeWitnessStructKey); + + // Mark both as primal since we're working with types + // (which don't need transposing) + // + builder->markInstAsPrimal(primalDiffType); + builder->markInstAsPrimal(diffWitness); + return InstPair(primal, diffWitness); } } @@ -762,12 +727,31 @@ InstPair AutoDiffTranscriberBase::transcribeLookupInterfaceMethod(IRBuilder* bui // result, it's useful to have a method to generate zero literals of any (arithmetic) type. // The current implementation requires that types are defined linearly. // -IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, IRType* originalType) +IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType( + IRBuilder* builder, IRType* originalType) { originalType = (IRType*)unwrapAttributedType(originalType); auto primalType = (IRType*)lookupPrimalInst(builder, originalType); if (auto diffType = differentiateType(builder, originalType)) { + IRInst* diffWitnessTable = nullptr; + IRType* diffOuterType = nullptr; + if (isExistentialType(diffType)) + { + // Emit null differential & pack it into an IDifferentiable existential. + + auto nullDiffValue = differentiableTypeConformanceContext.emitNullDifferential(builder); + builder->markInstAsDifferential(nullDiffValue, autoDiffSharedContext->nullDifferentialStructType); + + auto nullDiffExistential = builder->emitMakeExistential( + diffType, + nullDiffValue, + autoDiffSharedContext->nullDifferentialWitness); + builder->markInstAsDifferential(nullDiffExistential, primalType); + + return nullDiffExistential; + } + switch (diffType->getOp()) { case kIROp_DifferentialPairType: @@ -812,7 +796,7 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I // zero method from the same witness table. auto wt = lookupInterface->getWitnessTable(); zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); - builder->markInstAsDifferential(zeroMethod); + builder->markInstAsPrimal(zeroMethod); } else { @@ -825,7 +809,18 @@ IRInst* AutoDiffTranscriberBase::getDifferentialZeroOfType(IRBuilder* builder, I auto callInst = builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); builder->markInstAsDifferential(callInst, primalType); - return callInst; + if (diffOuterType && isExistentialType(diffOuterType)) + { + // Need to wrap the result back into an existential. + auto existentialZero = builder->emitMakeExistential( + diffOuterType, + callInst, + diffWitnessTable); + builder->markInstAsDifferential(existentialZero, primalType); + return existentialZero; + } + else + return callInst; } else { diff --git a/source/slang/slang-ir-autodiff-transcriber-base.h b/source/slang/slang-ir-autodiff-transcriber-base.h index d6b2ea9ff..e9acbcd99 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.h +++ b/source/slang/slang-ir-autodiff-transcriber-base.h @@ -87,13 +87,8 @@ struct AutoDiffTranscriberBase IRInst* maybeCloneForPrimalInst(IRBuilder* builder, IRInst* inst); - List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath( - IRInterfaceType* idiffType, IRInterfaceType* type); - InstPair transcribeExtractExistentialWitnessTable(IRBuilder* builder, IRInst* origInst); - IRInst* tryExtractConformanceFromInterfaceType(IRBuilder* builder, IRInterfaceType* type, IRWitnessTable* WitnessTable); - void maybeMigrateDifferentiableDictionaryFromDerivativeFunc(IRBuilder* builder, IRInst* origFunc); // Get or construct `:IDifferentiable` conformance for a DifferentiablePair. @@ -141,6 +136,10 @@ struct AutoDiffTranscriberBase IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType); + bool isExistentialType(IRType* type); + + void _markInstAsDifferential(IRBuilder* builder, IRInst* diffInst, IRInst* primalInst = nullptr); + virtual IRFuncType* differentiateFunctionType(IRBuilder* builder, IRInst* func, IRFuncType* funcType) = 0; // Create an empty func to represent the transcribed func of `origFunc`. diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index dad4ab192..bcebd2108 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -583,15 +583,19 @@ struct DiffTransposePass auto nextInst = inst->getNextInst(); if (auto varInst = as<IRVar>(inst)) { - if (auto diffDecor = varInst->findDecoration<IRDifferentialInstDecoration>()) + if (isDifferentialInst(varInst) && tryGetPrimalTypeFromDiffInst(varInst)) { - if (auto ptrPrimalType = as<IRPtrTypeBase>(diffDecor->getPrimalType())) + if (auto ptrPrimalType = as<IRPtrTypeBase>(tryGetPrimalTypeFromDiffInst(varInst))) { varInst->insertAtEnd(firstRevDiffBlock); auto dzero = emitDZeroOfDiffInstType(&builder, ptrPrimalType->getValueType()); builder.emitStore(varInst, dzero); } + else + { + SLANG_UNEXPECTED("Expected an pointer-typed differential variable."); + } } } inst = nextInst; @@ -1139,21 +1143,15 @@ struct DiffTransposePass // Normal differentiable input parameter will become an inout DiffPair parameter // in the propagate func. The split logic has already prepared the initial value // to pass in. We need to define a temp variable with this initial value and pass - // in the temp variable as argument to the inout parameter. + // in the temp variable as argument to the inout parameter. auto makePairArg = as<IRMakeDifferentialPair>(arg); SLANG_RELEASE_ASSERT(makePairArg); auto pairType = as<IRDifferentialPairType>(arg->getDataType()); auto var = builder->emitVar(arg->getDataType()); - - auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType); - auto zeroMethod = diffTypeContext.getDiffZeroMethodFromPairType(builder, pairType); - SLANG_ASSERT(zeroMethod); - auto diffZero = builder->emitCallInst( - diffType, - zeroMethod, - List<IRInst*>()); + + auto diffZero = emitDZeroOfDiffInstType(builder, pairType->getValueType()); // Initialize this var to (arg.primal, 0). builder->emitStore( @@ -1484,6 +1482,18 @@ struct DiffTransposePass case kIROp_FloatCast: return transposeFloatCast(builder, fwdInst, revValue); + case kIROp_MakeExistential: + return transposeMakeExistential(builder, fwdInst, revValue); + + case kIROp_ExtractExistentialValue: + return transposeExtractExistentialValue(builder, fwdInst, revValue); + + case kIROp_Reinterpret: + return transposeReinterpret(builder, fwdInst, revValue); + + case kIROp_PackAnyValue: + return transposePackAnyValue(builder, fwdInst, revValue); + case kIROp_LoadReverseGradient: case kIROp_ReverseGradientDiffPairRef: case kIROp_DefaultConstruct: @@ -1495,7 +1505,6 @@ struct DiffTransposePass case kIROp_Switch: case kIROp_LookupWitness: case kIROp_ExtractExistentialType: - case kIROp_ExtractExistentialValue: case kIROp_ExtractExistentialWitnessTable: { // Ignore. transposeBlock() should take care of adding the @@ -1574,7 +1583,7 @@ struct DiffTransposePass if (auto diffPairType = as<IRDifferentialPairType>(revVal->getDataType())) { revVal = builder->emitDifferentialPairGetDifferential( - (IRType*)diffTypeContext.getDifferentialTypeFromDiffPairType( + (IRType*)diffTypeContext.getDiffTypeFromPairType( builder, diffPairType), revVal); } @@ -1992,6 +2001,110 @@ struct DiffTransposePass fwdInst))); } + TranspositionResult transposeMakeExistential(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) + { + auto isExistentialType = [&](IRInst* type) -> bool + { + switch (type->getOp()) + { + case kIROp_ExtractExistentialType: + case kIROp_LookupWitness: + return true; + default: + return false; + } + }; + + auto diffType = fwdInst->getOperand(0)->getDataType(); + if (isExistentialType(diffType)) + { + // (A:IDiff = MakeExistential(B, W)) -> (dB: T += ExtractExistentialValue(dW)) + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Simple, + fwdInst->getOperand(0), + builder->emitExtractExistentialValue( + fwdInst->getOperand(0)->getDataType(), + revValue), + fwdInst))); + } + else + { + // We have a concrete type. + // (A:IDiff = MakeExistential(B, W)) -> + // (dB: T += ExtractExistentialValue(Reinterpret(dW))) + auto diffValInDiffType = builder->emitReinterpret( + diffType, + builder->emitExtractExistentialValue( + builder->emitExtractExistentialType(revValue), + revValue)); + + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Simple, + fwdInst->getOperand(0), + diffValInDiffType, + fwdInst))); + } + } + + TranspositionResult transposeExtractExistentialValue(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) + { + auto primalType = tryGetPrimalTypeFromDiffInst(fwdInst); + SLANG_ASSERT(primalType); + + // If we reach this point, revValue must be a differentiable type. + auto revTypeWitness = diffTypeContext.tryGetDifferentiableWitness( + builder, + primalType); + SLANG_ASSERT(revTypeWitness); + + auto baseExistential = fwdInst->getOperand(0); + + // (dA = ExtractExistentialValue(dB)) -> (dB += MakeExistential(T, A, ExtractExistentialWitness(B))) + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Simple, + baseExistential, + builder->emitMakeExistential( + baseExistential->getDataType(), + revValue, + revTypeWitness), + fwdInst))); + } + + TranspositionResult transposeReinterpret(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) + { + // (A = reinterpret<T, U>(B)) -> (dB += reinterpret<U, T>(dA)) + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Simple, + fwdInst->getOperand(0), + builder->emitReinterpret( + fwdInst->getOperand(0)->getDataType(), + revValue), + fwdInst))); + } + + + TranspositionResult transposePackAnyValue(IRBuilder* builder, IRInst* fwdInst, IRInst* revValue) + { + // (A = packAnyValue<T, U>(B)) -> (dB += unpackAnyValue<U, T>(dA)) + return TranspositionResult( + List<RevGradient>( + RevGradient( + RevGradient::Flavor::Simple, + fwdInst->getOperand(0), + builder->emitUnpackAnyValue( + fwdInst->getOperand(0)->getDataType(), + revValue), + fwdInst))); + } + // Gather all reverse-mode gradients for a Load inst, aggregate them and store them in the ptr. // void accumulateGradientsForLoad(IRBuilder* builder, IRLoad* revLoad) @@ -2681,13 +2794,18 @@ struct DiffTransposePass { // Look for differential inst decoration. if (auto diffInstDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>()) - { return diffInstDecoration->getPrimalType(); - } - else - { - return nullptr; - } + + return nullptr; + } + + IRInst* tryGetWitnessFromDiffInst(IRInst* diffInst) + { + // Look for differential inst decoration. + if (auto diffInstDecoration = diffInst->findDecoration<IRDifferentialInstDecoration>()) + return diffInstDecoration->getWitness(); + + return nullptr; } IRInst* emitDZeroOfDiffInstType(IRBuilder* builder, IRType* primalType) @@ -2709,6 +2827,16 @@ struct DiffTransposePass auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primalZero, diffZero); } + else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType)) + { + // Pack a null value into an existential type. + auto existentialZero = builder->emitMakeExistential( + autodiffContext->differentiableInterfaceType, + diffTypeContext.emitNullDifferential(builder), + autodiffContext->nullDifferentialWitness); + + return existentialZero; + } auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, primalType); @@ -2720,6 +2848,19 @@ struct DiffTransposePass zeroMethod, List<IRInst*>()); } + + IRInst* emitDAddForExistentialType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2) + { + auto existentialDAddFunc = diffTypeContext.getOrCreateExistentialDAddMethod(); + + // Should exist. + SLANG_ASSERT(existentialDAddFunc); + + return builder->emitCallInst( + (IRType*)diffTypeContext.getDifferentialForType(builder, primalType), + existentialDAddFunc, + List<IRInst*>({ op1, op2 })); + } IRInst* emitDAddOfDiffInstType(IRBuilder* builder, IRType* primalType, IRInst* op1, IRInst* op2) { @@ -2764,6 +2905,13 @@ struct DiffTransposePass auto diffDiffPairType = builder->getDifferentialPairUserCodeType(diffType, diffWitness); return builder->emitMakeDifferentialPairUserCode(diffDiffPairType, primal, diff); } + else if (as<IRInterfaceType>(primalType) || as<IRAssociatedType>(primalType)) + { + // If our type is existential, we need to handle the case where + // one or both of our operands are null-type. + // + return emitDAddForExistentialType(builder, primalType, op1, op2); + } auto addMethod = diffTypeContext.getAddMethodForType(builder, primalType); diff --git a/source/slang/slang-ir-autodiff-unzip.cpp b/source/slang/slang-ir-autodiff-unzip.cpp index 95ad0d921..2857424f9 100644 --- a/source/slang/slang-ir-autodiff-unzip.cpp +++ b/source/slang/slang-ir-autodiff-unzip.cpp @@ -417,7 +417,7 @@ IRFunc* DiffUnzipPass::extractPrimalFunc( inst->getFullType(), intermediateVar, structKeyDecor->getStructKey()); - iuse->set(val); + builder.replaceOperand(iuse, val); } } instsToRemove.add(inst); diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 4846fc840..c57dc300f 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -88,7 +88,7 @@ struct DiffUnzipPass } if (auto pairType = as<IRDifferentialPairType>(type)) { - IRInst* diffType = diffTypeContext.getDifferentialTypeFromDiffPairType(builder, pairType); + IRInst* diffType = diffTypeContext.getDiffTypeFromPairType(builder, pairType); if (as<IRPtrTypeBase>(primalParam->getFullType())) diffType = builder->getPtrType(primalParam->getFullType()->getOp(), (IRType*)diffType); auto primalRef = builder->emitPrimalParamRef(primalParam); @@ -286,7 +286,8 @@ struct DiffUnzipPass if (auto fwdPairResultType = as<IRDifferentialPairType>(mixedDecoration->getPairType())) { primalType = fwdPairResultType->getValueType(); - diffType = (IRType*)diffTypeContext.getDifferentialForType(&globalBuilder, primalType); + diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(&globalBuilder, fwdPairResultType); + SLANG_ASSERT(diffType); resultType = fwdPairResultType; } diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index cb710ac6b..645662caa 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -35,6 +35,15 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK return entry->getSatisfyingVal(); } } + else if (auto interfaceType = as<IRInterfaceType>(witness)) + { + for (UIndex ii = 0; ii < interfaceType->getOperandCount(); ii++) + { + auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(ii)); + if (entry->getRequirementKey() == requirementKey) + return entry->getRequirementVal(); + } + } else { return builder->emitLookupInterfaceMethodInst( @@ -47,8 +56,17 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK static IRInst* _getDiffTypeFromPairType(AutoDiffSharedContext*sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) { - auto witnessTable = type->getWitness(); - return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeStructKey); + auto witness = type->getWitness(); + SLANG_RELEASE_ASSERT(witness); + + // Special case when the primal type is an InterfaceType/AssociatedType + if (as<IRInterfaceType>(type->getValueType()) || as<IRAssociatedType>(type->getValueType())) + { + // The differential type is the IDifferentiable interface type. + return sharedContext->differentiableInterfaceType; + } + + return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); } static IRInst* _getDiffTypeWitnessFromPairType(AutoDiffSharedContext* sharedContext, IRBuilder* builder, IRDifferentialPairTypeBase* type) @@ -332,6 +350,8 @@ AutoDiffSharedContext::AutoDiffSharedContext(IRModuleInst* inModuleInst) zeroMethodStructKey = findZeroMethodStructKey(); addMethodStructKey = findAddMethodStructKey(); mulMethodStructKey = findMulMethodStructKey(); + nullDifferentialStructType = findNullDifferentialStructType(); + nullDifferentialWitness = findNullDifferentialWitness(); if (differentialAssocTypeStructKey) isInterfaceAvailable = true; @@ -362,6 +382,47 @@ IRInst* AutoDiffSharedContext::findDifferentiableInterface() return nullptr; } +IRStructType* AutoDiffSharedContext::findNullDifferentialStructType() +{ + if (auto module = as<IRModuleInst>(moduleInst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + // TODO: Also a particularly dangerous way to look for a struct... + if (auto structType = as<IRStructType>(globalInst)) + { + if (auto decor = structType->findDecoration<IRNameHintDecoration>()) + { + if (decor->getName() == toSlice("NullDifferential")) + { + return structType; + } + } + } + } + } + return nullptr; +} + +IRInst* AutoDiffSharedContext::findNullDifferentialWitness() +{ + if (auto module = as<IRModuleInst>(moduleInst)) + { + for (auto globalInst : module->getGlobalInsts()) + { + if (auto witnessTable = as<IRWitnessTable>(globalInst)) + { + if (witnessTable->getConformanceType() == differentiableInterfaceType + && witnessTable->getConcreteType() == nullDifferentialStructType) + return witnessTable; + } + } + } + + return nullptr; +} + + IRStructKey* AutoDiffSharedContext::getIDifferentiableStructKeyAtIndex(UInt index) { if (as<IRModuleInst>(moduleInst) && differentiableInterfaceType) @@ -442,11 +503,9 @@ IRInst* DifferentiableTypeConformanceContext::lookUpInterfaceMethod(IRBuilder* b } IRInst* DifferentiableTypeConformanceContext::getDifferentialTypeFromDiffPairType( - IRBuilder* builder, IRDifferentialPairTypeBase* diffPairType) + IRBuilder*, IRDifferentialPairTypeBase*) { - auto witness = diffPairType->getWitness(); - SLANG_RELEASE_ASSERT(witness); - return _lookupWitness(builder, witness, sharedContext->differentialAssocTypeStructKey); + SLANG_UNIMPLEMENTED_X(""); } IRInst* DifferentiableTypeConformanceContext::getDiffTypeFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) @@ -471,6 +530,189 @@ IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBui return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey); } +IRInst *DifferentiableTypeConformanceContext::tryExtractConformanceFromInterfaceType(IRBuilder *builder, IRInterfaceType *interfaceType, IRWitnessTable *witnessTable) +{ + SLANG_RELEASE_ASSERT(interfaceType); + + List<IRInterfaceRequirementEntry*> lookupKeyPath = findDifferentiableInterfaceLookupPath( + sharedContext->differentiableInterfaceType, interfaceType); + + IRInst* differentialTypeWitness = witnessTable; + if (lookupKeyPath.getCount()) + { + // `interfaceType` does conform to `IDifferentiable`. + for (auto node : lookupKeyPath) + { + differentialTypeWitness = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), differentialTypeWitness, node->getRequirementKey()); + // Lookup insts are always primal values. + builder->markInstAsPrimal(differentialTypeWitness); + } + return differentialTypeWitness; + } + + return nullptr; +} + +// Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. +static bool _findDifferentiableInterfaceLookupPathImpl( + HashSet<IRInst*>& processedTypes, + IRInterfaceType* idiffType, + IRInterfaceType* type, + List<IRInterfaceRequirementEntry*>& currentPath) +{ + if (processedTypes.contains(type)) + return false; + processedTypes.add(type); + + List<IRInterfaceRequirementEntry*> lookupKeyPath; + for (UInt i = 0; i < type->getOperandCount(); i++) + { + auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i)); + if (!entry) continue; + if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal())) + { + currentPath.add(entry); + if (wt->getConformanceType() == idiffType) + { + return true; + } + else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType())) + { + if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) + return true; + } + currentPath.removeLast(); + } + } + return false; +} + +List<IRInterfaceRequirementEntry *> DifferentiableTypeConformanceContext::findDifferentiableInterfaceLookupPath(IRInterfaceType *idiffType, IRInterfaceType *type) +{ + List<IRInterfaceRequirementEntry*> currentPath; + HashSet<IRInst*> processedTypes; + _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); + return currentPath; +} + +IRFunc *DifferentiableTypeConformanceContext::getOrCreateExistentialDAddMethod() +{ + if (this->existentialDAddFunc) + return this->existentialDAddFunc; + + SLANG_ASSERT(sharedContext->differentiableInterfaceType); + SLANG_ASSERT(sharedContext->nullDifferentialWitness); + + auto builder = IRBuilder(this->sharedContext->moduleInst); + + existentialDAddFunc = builder.createFunc(); + existentialDAddFunc->setFullType(builder.getFuncType( + List<IRType*>({ + sharedContext->differentiableInterfaceType, + sharedContext->differentiableInterfaceType, + }), + sharedContext->differentiableInterfaceType)); + + builder.setInsertInto(existentialDAddFunc); + auto entryBlock = builder.emitBlock(); + + builder.setInsertInto(entryBlock); + + // Insert parameters. + auto aObj = builder.emitParam(sharedContext->differentiableInterfaceType); + auto bObj = builder.emitParam(sharedContext->differentiableInterfaceType); + + // Check if a.type == null_differential.type + auto aObjWitnessIsNull = builder.emitIsDifferentialNull(aObj); + + // If aObjWitnessTable is null, return bObj. + auto aObjWitnessIsNullBlock = builder.emitBlock(); + builder.setInsertInto(aObjWitnessIsNullBlock); + builder.emitReturn(bObj); + + auto aObjWitnessIsNotNullBlock = builder.emitBlock(); + builder.setInsertInto(aObjWitnessIsNotNullBlock); + + // Check if b.type == null_differential.type + auto bObjWitnessIsNull = builder.emitIsDifferentialNull(bObj); + + // If bObjWitnessTable is null, return aObj. + auto bObjWitnessIsNullBlock = builder.emitBlock(); + builder.setInsertInto(bObjWitnessIsNullBlock); + builder.emitReturn(aObj); + + auto bObjWitnessIsNotNullBlock = builder.emitBlock(); + + // Emit aObj.type::dadd(aObj.val, bObj.val) + // + // Important: we're looking up dadd on the differential type, and + // not the primal type. This assumes that the two methods are identical, + // which (mathematically) they should be. + // + auto concreteDiffTypeWitnessTable = builder.emitExtractExistentialWitnessTable(aObj); + + // Extract func type from the witness table type. + IRFuncType* dAddFuncType = nullptr; + for (UIndex ii = 0; ii < sharedContext->differentiableInterfaceType->getOperandCount(); ii++) + { + auto entry = cast<IRInterfaceRequirementEntry>(sharedContext->differentiableInterfaceType->getOperand(ii)); + if (entry->getRequirementKey() == sharedContext->addMethodStructKey) + { + dAddFuncType = cast<IRFuncType>(entry->getRequirementVal()); + break; + } + } + + SLANG_ASSERT(dAddFuncType); + + auto dAddMethod = builder.emitLookupInterfaceMethodInst( + dAddFuncType, + concreteDiffTypeWitnessTable, + sharedContext->addMethodStructKey); + + // Call + auto dAddResult = builder.emitCallInst( + dAddFuncType->getResultType(), + dAddMethod, + List<IRInst*>({ + builder.emitExtractExistentialValue(dAddFuncType->getParamType(0), aObj), + builder.emitExtractExistentialValue(dAddFuncType->getParamType(1), bObj)})); + + // Wrap result in existential. + auto existentialDiffType = builder.emitMakeExistential( + sharedContext->differentiableInterfaceType, + dAddResult, + concreteDiffTypeWitnessTable); + + builder.emitReturn(existentialDiffType); + + // Emit an unreachable block to act as the after block. + auto unreachableBlock = builder.emitBlock(); + builder.setInsertInto(unreachableBlock); + builder.emitUnreachable(); + + // Link up conditional blocks. + builder.setInsertInto(entryBlock); + builder.emitIfElse( + aObjWitnessIsNull, + aObjWitnessIsNullBlock, + aObjWitnessIsNotNullBlock, + unreachableBlock); + + builder.setInsertInto(aObjWitnessIsNotNullBlock); + builder.emitIfElse( + bObjWitnessIsNull, + bObjWitnessIsNullBlock, + bObjWitnessIsNotNullBlock, + unreachableBlock); + + builder.addNameHintDecoration(existentialDAddFunc, UnownedStringSlice("__existential_dadd")); + builder.addBackwardDifferentiableDecoration(existentialDAddFunc); + + this->existentialDAddFunc = existentialDAddFunc; + return existentialDAddFunc; +} + void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { for (auto globalInst : sharedContext->moduleInst->getChildren()) @@ -745,9 +987,20 @@ IRInst* DifferentiableTypeConformanceContext::getArrayWitness(IRBuilder* builder return table; } -IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness(IRBuilder*, IRExtractExistentialType*) +IRInst* DifferentiableTypeConformanceContext::getExtractExistensialTypeWitness( + IRBuilder* builder, + IRExtractExistentialType* extractExistentialType) { - SLANG_UNIMPLEMENTED_X("TODO: Implement"); + // Check that the type's base is differentiable + if (differentiateType(builder, extractExistentialType->getOperand(0)->getDataType())) + { + return tryExtractConformanceFromInterfaceType( + builder, + cast<IRInterfaceType>(extractExistentialType->getOperand(0)->getDataType()), + (IRWitnessTable*)builder->emitExtractExistentialWitnessTable(extractExistentialType->getOperand(0))); + } + + return nullptr; } @@ -1761,6 +2014,71 @@ void removeDetachInsts(IRModule* module) pass.processModule(); } +struct LowerNullCheckPass : InstPassBase +{ + LowerNullCheckPass(IRModule* module, AutoDiffSharedContext* context) : + InstPassBase(module), context(context) + { + } + void processModule() + { + List<IRInst*> nullCheckInsts; + processInstsOfType<IRIsDifferentialNull>(kIROp_IsDifferentialNull, [&](IRIsDifferentialNull* isDiffNullInst) + { + IRBuilder builder(module); + builder.setInsertBefore(isDiffNullInst); + + // Extract existential type from the operand. + auto operand = isDiffNullInst->getBase(); + auto operandConcreteWitness = builder.emitExtractExistentialWitnessTable(operand); + auto witnessID = builder.emitGetSequentialIDInst(operandConcreteWitness); + + auto nullDiffWitnessTable = context->nullDifferentialWitness; + auto nullDiffWitnessID = builder.emitGetSequentialIDInst(nullDiffWitnessTable); + + // Compare the concrete type with the null differential witness table. + auto isDiffNull = builder.emitEql(witnessID, nullDiffWitnessID); + + isDiffNullInst->replaceUsesWith(isDiffNull); + nullCheckInsts.add(isDiffNullInst); + }); + + for (auto nullCheckInst : nullCheckInsts) + { + nullCheckInst->removeAndDeallocate(); + } + } + + private: + AutoDiffSharedContext* context; +}; + +void lowerNullCheckInsts(IRModule* module, AutoDiffSharedContext* context) +{ + LowerNullCheckPass pass(module, context); + pass.processModule(); +} + +void releaseNullDifferentialType(AutoDiffSharedContext* context) +{ + if (auto nullStruct = context->nullDifferentialStructType) + { + if (auto publicDecoration = nullStruct->findDecoration<IRPublicDecoration>()) + publicDecoration->removeAndDeallocate(); + if (auto keepAliveDecoration = nullStruct->findDecoration<IRKeepAliveDecoration>()) + keepAliveDecoration->removeAndDeallocate(); + } + + if (auto nullWitness = context->nullDifferentialWitness) + { + if (auto publicDecoration = nullWitness->findDecoration<IRPublicDecoration>()) + publicDecoration->removeAndDeallocate(); + if (auto keepAliveDecoration = nullWitness->findDecoration<IRKeepAliveDecoration>()) + keepAliveDecoration->removeAndDeallocate(); + } + +} + bool finalizeAutoDiffPass(IRModule* module) { bool modified = false; @@ -1777,17 +2095,25 @@ bool finalizeAutoDiffPass(IRModule* module) removeDetachInsts(module); + lowerNullCheckInsts(module, &autodiffContext); + stripNoDiffTypeAttribute(module); // Remove auto-diff related decorations. stripAutoDiffDecorations(module); + // Remove keep-alive decorations from null-differential type + // so it can be DCE'd if unused. + // + releaseNullDifferentialType(&autodiffContext); + return modified; } IRBlock* getBlock(IRInst* inst) { - SLANG_RELEASE_ASSERT(inst); + if (!inst) + return nullptr; if (auto block = as<IRBlock>(inst)) return block; diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index fdbf5c65e..be51fba6f 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -92,6 +92,16 @@ struct AutoDiffSharedContext IRStructKey* mulMethodStructKey = nullptr; + // Refernce to NullDifferential struct type. These are used + // as sentinel values for uninitialized existential (interface-typed) + // differentials. + // + IRStructType* nullDifferentialStructType = nullptr; + + // Reference to the NullDifferential : IDifferentiable witness. + // + IRInst* nullDifferentialWitness = nullptr; + // Modules that don't use differentiable types // won't have the IDifferentiable interface type available. @@ -109,6 +119,10 @@ private: IRInst* findDifferentiableInterface(); + IRStructType *findNullDifferentialStructType(); + + IRInst *findNullDifferentialWitness(); + IRStructKey* findDifferentialTypeStructKey() { return getIDifferentiableStructKeyAtIndex(0); @@ -144,9 +158,17 @@ struct DifferentiableTypeConformanceContext IRGlobalValueWithCode* parentFunc = nullptr; OrderedDictionary<IRType*, IRInst*> differentiableWitnessDictionary; + IRFunc* existentialDAddFunc = nullptr; + DifferentiableTypeConformanceContext(AutoDiffSharedContext* shared) : sharedContext(shared) - {} + { + // Populate dictionary with null differential type. + if (sharedContext->nullDifferentialStructType) + differentiableWitnessDictionary.add( + sharedContext->nullDifferentialStructType, + sharedContext->nullDifferentialWitness); + } void setFunc(IRGlobalValueWithCode* func); @@ -181,6 +203,15 @@ struct DifferentiableTypeConformanceContext IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + IRInst* tryExtractConformanceFromInterfaceType( + IRBuilder* builder, + IRInterfaceType* interfaceType, + IRWitnessTable* witnessTable); + + List<IRInterfaceRequirementEntry*> findDifferentiableInterfaceLookupPath( + IRInterfaceType* idiffType, + IRInterfaceType* type); + // Lookup and return the 'Differential' type declared in the concrete type // in order to conform to the IDifferentiable interface. // Note that inside a generic block, this will be a witness table lookup instruction @@ -190,6 +221,13 @@ struct DifferentiableTypeConformanceContext { switch (origType->getOp()) { + case kIROp_InterfaceType: + { + if (isDifferentiableType(origType)) + return this->sharedContext->differentiableInterfaceType; + else + return nullptr; + } case kIROp_ArrayType: { auto diffElementType = (IRType*)getDifferentialForType( @@ -249,6 +287,17 @@ struct DifferentiableTypeConformanceContext auto result = lookUpInterfaceMethod(builder, origType, sharedContext->addMethodStructKey); return result; } + + IRInst* emitNullDifferential(IRBuilder* builder) + { + return builder->emitCallInst( + sharedContext->nullDifferentialStructType, + getZeroMethodForType(builder, sharedContext->nullDifferentialStructType), + List<IRInst*>()); + } + + IRFunc* getOrCreateExistentialDAddMethod(); + }; struct DifferentialPairTypeBuilder diff --git a/source/slang/slang-ir-clone.cpp b/source/slang/slang-ir-clone.cpp index 9258e511f..43b60d6ed 100644 --- a/source/slang/slang-ir-clone.cpp +++ b/source/slang/slang-ir-clone.cpp @@ -267,7 +267,15 @@ IRInst* cloneInst( env, builder, oldInst); env->mapOldValToNew.add(oldInst, newInst); - + + // For hoistable insts, its possible that the cloned inst is the same + // as the original inst. + // Skip the decoration/children cloning in that case (which will end up + // in an infinite loop) + // + if (newInst == oldInst) + return newInst; + cloneInstDecorationsAndChildren( env, builder->getModule(), oldInst, newInst); diff --git a/source/slang/slang-ir-generics-lowering-context.h b/source/slang/slang-ir-generics-lowering-context.h index c8a7be3ee..509df6a33 100644 --- a/source/slang/slang-ir-generics-lowering-context.h +++ b/source/slang/slang-ir-generics-lowering-context.h @@ -33,7 +33,6 @@ namespace Slang Dictionary<IRInterfaceType*, IRInterfaceType*> loweredInterfaceTypes; Dictionary<IRInterfaceType*, IRInterfaceType*> mapLoweredInterfaceToOriginal; - // Dictionaries for interface type requirement key-value lookups. // Used by `findInterfaceRequirementVal`. Dictionary<IRInterfaceType*, Dictionary<IRInst*, IRInst*>> mapInterfaceRequirementKeyValue; diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index f110c07e7..a8fdd8202 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -362,6 +362,9 @@ INST(PrimalParamRef, PrimalParamRef, 1, 0) // to represent a reference to an inout parameter for use in the back-prop part of the computation. INST(DiffParamRef, DiffParamRef, 1, 0) +// Check that the value is a differential null value. +INST(IsDifferentialNull, IsDifferentialNull, 1, 0) + INST(FieldExtract, get_field, 2, 0) INST(FieldAddress, get_field_addr, 2, 0) @@ -935,8 +938,8 @@ INST(WrapExistential, wrapExistential, 1, 0) INST(GetValueFromBoundInterface, getValueFromBoundInterface, 1, 0) INST(ExtractExistentialValue, extractExistentialValue, 1, 0) -INST(ExtractExistentialType, extractExistentialType, 1, 0) -INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, 0) +INST(ExtractExistentialType, extractExistentialType, 1, HOISTABLE) +INST(ExtractExistentialWitnessTable, extractExistentialWitnessTable, 1, HOISTABLE) INST(ExtractTaggedUnionTag, extractTaggedUnionTag, 1, 0) INST(ExtractTaggedUnionPayload, extractTaggedUnionPayload, 1, 0) diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index dade0e2f4..adfcac7fd 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -819,6 +819,7 @@ struct IRDifferentialInstDecoration : IRAutodiffInstDecoration IRType* getPrimalType() { return (IRType*)(getOperand(0)); } IRInst* getPrimalInst() { return getOperand(1); } + IRInst* getWitness() { return getOperand(2); } }; struct IRPrimalInstDecoration : IRAutodiffInstDecoration @@ -1018,6 +1019,17 @@ struct IRBackwardDifferentiate : IRInst IR_LEAF_ISA(BackwardDifferentiate) }; +struct IRIsDifferentialNull : IRInst +{ + enum + { + kOp = kIROp_IsDifferentialNull + }; + IRInst* getBase() { return getOperand(0); } + + IR_LEAF_ISA(IsDifferentialNull) +}; + // Retrieves the primal substitution function for the given function. struct IRPrimalSubstitute : IRInst { @@ -3223,6 +3235,7 @@ public: IRInst* emitBackwardDifferentiatePropagateInst(IRType* type, IRInst* baseFn); IRInst* emitPrimalSubstituteInst(IRType* type, IRInst* baseFn); IRInst* emitDetachDerivative(IRType* type, IRInst* value); + IRInst* emitIsDifferentialNull(IRInst* value); IRInst* emitDispatchKernelInst(IRType* type, IRInst* baseFn, IRInst* threadGroupSize, IRInst* dispatchSize, Int argCount, IRInst* const* inArgs); IRInst* emitCudaKernelLaunch(IRInst* baseFn, IRInst* gridDim, IRInst* blockDim, IRInst* argsArray, IRInst* cudaStream); @@ -4177,6 +4190,12 @@ public: addDecoration(value, kIROp_DifferentialInstDecoration, primalType, primalInst); } + void markInstAsDifferential(IRInst* value, IRType* primalType, IRInst* primalInst, IRInst* witnessTable) + { + IRInst* args[] = { primalType, primalInst, witnessTable }; + addDecoration(value, kIROp_DifferentialInstDecoration, args, 3); + } + void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) { addDecoration(value, kIROp_COMWitnessDecoration, &witnessTable, 1); diff --git a/source/slang/slang-ir-lower-generics.cpp b/source/slang/slang-ir-lower-generics.cpp index f535b97d2..19f713548 100644 --- a/source/slang/slang-ir-lower-generics.cpp +++ b/source/slang/slang-ir-lower-generics.cpp @@ -2,6 +2,7 @@ #include "slang-ir-lower-generics.h" #include "slang-ir-any-value-marshalling.h" +#include "slang-ir-any-value-inference.h" #include "slang-ir-augment-make-existential.h" #include "slang-ir-generics-lowering-context.h" #include "slang-ir-lower-existential.h" @@ -15,7 +16,10 @@ #include "slang-ir-witness-table-wrapper.h" #include "slang-ir-ssa-simplification.h" #include "slang-ir-util.h" +#include "slang-ir-layout.h" + #include "../core/slang-performance-profiler.h" +#include "../core/slang-func-ptr.h" namespace Slang { @@ -213,6 +217,8 @@ namespace Slang checkTypeConformanceExists(&sharedContext); + inferAnyValueSizeWhereNecessary(module); + // Replace all `makeExistential` insts with `makeExistentialWithRTTI` // before making any other changes. This is necessary because a parameter of // generic type will be lowered into `AnyValueType`, and after that we can no longer diff --git a/source/slang/slang-ir-lower-reinterpret.cpp b/source/slang/slang-ir-lower-reinterpret.cpp index 7575c8f12..689cc8505 100644 --- a/source/slang/slang-ir-lower-reinterpret.cpp +++ b/source/slang/slang-ir-lower-reinterpret.cpp @@ -3,6 +3,7 @@ #include "slang-ir-insts.h" #include "slang-ir-layout.h" #include "slang-ir-any-value-marshalling.h" +#include "slang-ir-any-value-inference.h" namespace Slang { @@ -84,6 +85,11 @@ struct ReinterpretLoweringContext void lowerReinterpret(TargetRequest* targetReq, IRModule* module, DiagnosticSink* sink) { + // Before processing reinterpret insts, ensure that existential types without + // user-defined sizes have inferred sizes where possible. + // + inferAnyValueSizeWhereNecessary(module); + ReinterpretLoweringContext context; context.module = module; context.targetReq = targetReq; diff --git a/source/slang/slang-ir-ssa.cpp b/source/slang/slang-ir-ssa.cpp index 18eba677e..730943bf8 100644 --- a/source/slang/slang-ir-ssa.cpp +++ b/source/slang/slang-ir-ssa.cpp @@ -412,7 +412,7 @@ PhiInfo* addPhi( { valueType = context->getBuilder()->getRateQualifiedType(rate, valueType); } - IRParam* phi = builder->createParam(valueType); + IRParam* phi = builder->emitParam(valueType); cloneRelevantDecorations(var, phi); RefPtr<PhiInfo> phiInfo = new PhiInfo(); @@ -503,6 +503,7 @@ IRInst* tryRemoveTrivialPhi( // replace uses of the phi (including its possible uses // of itself) with the unique non-phi value. phi->replaceUsesWith(same); + phi->removeAndDeallocate(); // Clear out the operands to the phi, since they won't // actually get used in the program any more. @@ -849,11 +850,12 @@ void processBlock( // leave them as-is, or replace them with a value // that we look up with local/global value numbering - IRInst* next = nullptr; - for (auto ii = block->getFirstInst(); ii; ii = next) - { - next = ii->getNextInst(); + List<IRInst*> workList; + for (auto ii = block->getFirstInst(); ii; ii = ii->getNextInst()) + workList.add(ii); + for (auto& ii : workList) + { // Any new instructions we create to represent // the new value will get inserted before whatever // instruction we are working with. @@ -1117,6 +1119,14 @@ bool constructSSA(ConstructSSAContext* context) { auto blockInfo = *context->blockInfos.tryGetValue(bb); + // First remove phis from their parent blocks. + for (auto phiInfo : blockInfo->phis) + if (!phiInfo->replacement) + phiInfo->phi->removeFromParent(); + + // Then, add them back in a consistent order, and add predecessor + // args in the same order. + // for (auto phiInfo : blockInfo->phis) { // If we replaced this phi with another value, diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index c666ccc08..8d36c2e86 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -3268,6 +3268,17 @@ namespace Slang return inst; } + IRInst *IRBuilder::emitIsDifferentialNull(IRInst *value) + { + auto inst = createInst<IRIsDifferentialNull>( + this, + kIROp_IsDifferentialNull, + getBoolType(), + value); + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitBackwardDifferentiateInst(IRType* type, IRInst* baseFn) { auto inst = createInst<IRBackwardDifferentiate>( diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 1ed79fbe3..489a89287 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2813,6 +2813,16 @@ void collectParameterLists( auto noDiffAttr = context->astBuilder->getNoDiffModifierVal(); thisType = context->astBuilder->getModifiedType(thisType, 1, &noDiffAttr); } + else if (auto fwdDerivDeclRef = declRef.as<ForwardDerivativeRequirementDecl>()) + { + thisType = fwdDerivDeclRef.getDecl()->diffThisType; + } + else if (auto bwdDerivDeclRef = declRef.as<BackwardDerivativeRequirementDecl>()) + { + thisType = bwdDerivDeclRef.getDecl()->diffThisType; + innerThisParamDirection = kParameterDirection_InOut; + } + addThisParameter(innerThisParamDirection, thisType, ioParameterLists); } } @@ -7235,7 +7245,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } auto assocType = context->irBuilder->getAssociatedType( constraintInterfaces.getArrayView().arrayView); - context->setValue(decl, assocType); + context->setValue(decl, assocType); return LoweredValInfo::simple(assocType); } @@ -8446,14 +8456,6 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> addNameHint(subContext, irFunc, decl); addLinkageDecoration(subContext, irFunc, decl); - if (decl->body) - { - if (auto differentialAttr = decl->findModifier<DifferentiableAttribute>()) - { - lowerDifferentiableAttribute(subContext, irFunc, differentialAttr); - } - } - // Always force inline diff setter accessor to prevent downstream compiler from complaining // fields are not fully initialized for the first `inout` parameter. if (as<SetterDecl>(decl)) @@ -8927,6 +8929,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_PreferRecomputeDecoration); } } + + if (auto diffAttr = decl->findModifier<DifferentiableAttribute>()) + { + if (decl->body) + { + subContext->irBuilder->setInsertInto(irFunc->getParent()); + lowerDifferentiableAttribute(subContext, irFunc, diffAttr); + subContext->irBuilder->setInsertInto(irFunc); + } + } + // For convenience, ensure that any additional global // values that were emitted while outputting the function // body appear before the function itself in the list |
