diff options
| author | Yong He <yonghe@outlook.com> | 2022-12-01 18:55:43 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-12-01 18:55:43 -0800 |
| commit | e7df8538eb8f0ed06f0838d946bec8e9e0fe0985 (patch) | |
| tree | 3c08e646600ab82ffda260f2b6deb96dd2085776 /source | |
| parent | f51f69d045d9e0b83d9ab1f4623d4319ce1867be (diff) | |
Allow `no_diff` on `this` parameter. (#2543)
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/diff.meta.slang | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-ast-builder.h | 6 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-check-conformance.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 132 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 10 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 9 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-emit.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 187 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-pairs.cpp | 38 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.cpp | 59 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff.h | 22 | ||||
| -rw-r--r-- | source/slang/slang-ir-hoist-local-types.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-hoist-local-types.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-syntax.cpp | 8 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 5 |
20 files changed, 386 insertions, 135 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index 69ced9156..033c173ab 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -11,13 +11,16 @@ attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; __attributeTarget(FunctionDeclBase) attribute_syntax [BackwardDifferentiable] : BackwardDifferentiableAttribute; - __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; __attributeTarget(DeclBase) attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; +// Exclude "this" parameter from differentiation. +__attributeTarget(FunctionDeclBase) +attribute_syntax [NoDiffThis] : NoDiffThisAttribute; + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. diff --git a/source/slang/slang-ast-builder.cpp b/source/slang/slang-ast-builder.cpp index 623a9161b..ab161065d 100644 --- a/source/slang/slang-ast-builder.cpp +++ b/source/slang/slang-ast-builder.cpp @@ -141,6 +141,16 @@ Type* SharedASTBuilder::getNoneType() return m_noneType; } +Type* SharedASTBuilder::getDiffInterfaceType() +{ + if (!m_diffInterfaceType) + { + auto decl = findMagicDecl("DifferentiableType"); + m_diffInterfaceType = DeclRefType::create(m_astBuilder, makeDeclRef<Decl>(decl)); + } + return m_diffInterfaceType; +} + SharedASTBuilder::~SharedASTBuilder() { // Release built in types.. diff --git a/source/slang/slang-ast-builder.h b/source/slang/slang-ast-builder.h index bdc03dda5..72d8ec50a 100644 --- a/source/slang/slang-ast-builder.h +++ b/source/slang/slang-ast-builder.h @@ -36,6 +36,8 @@ public: Type* getNullPtrType(); /// Get the NullPtr type Type* getNoneType(); + /// Get the `IDifferentiable` type + Type* getDiffInterfaceType(); const ReflectClassInfo* findClassInfo(Name* name); SyntaxClass<NodeBase> findSyntaxClass(Name* name); @@ -85,7 +87,7 @@ protected: Type* m_dynamicType = nullptr; Type* m_nullPtrType = nullptr; Type* m_noneType = nullptr; - Type* m_diffBottomType = nullptr; + Type* m_diffInterfaceType = nullptr; Type* m_builtinTypes[Index(BaseType::CountOf)]; Dictionary<String, Decl*> m_magicDecls; @@ -308,7 +310,7 @@ public: Type* getNullPtrType() { return m_sharedASTBuilder->getNullPtrType(); } Type* getNoneType() { return m_sharedASTBuilder->getNoneType(); } Type* getEnumTypeType() { return m_sharedASTBuilder->getEnumTypeType(); } - + Type* getDiffInterfaceType() { return m_sharedASTBuilder->getDiffInterfaceType(); } // Construct the type `Ptr<valueType>`, where `Ptr` // is looked up as a builtin type. PtrType* getPtrType(Type* valueType); diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 2adbcf6c6..c85464061 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1113,6 +1113,13 @@ class BackwardDerivativeOfAttribute : public DifferentiableAttribute Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; + /// The `[NoDiffThis]` attribute is used to specify that the `this` parameter should not be + /// included for differentiation. +class NoDiffThisAttribute : public Attribute +{ + SLANG_AST_CLASS(NoDiffThisAttribute) +}; + /// Indicates that the modified declaration is one of the "magic" declarations /// that NVAPI uses to communicate extended operations. When NVAPI is being included /// via the prelude for downstream compilation, declarations with this modifier diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index 4d983b746..3a50897de 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -527,6 +527,11 @@ namespace Slang return false; } + bool SemanticsVisitor::isTypeDifferentiable(Type* type) + { + return isDeclaredSubtype(type, m_astBuilder->getDiffInterfaceType()); + } + Val* SemanticsVisitor::tryGetSubtypeWitness( Type* subType, DeclRef<AggTypeDecl> superTypeDeclRef) diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index d36e6286d..d8968e33a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -340,6 +340,16 @@ namespace Slang return isEffectivelyStatic(decl, parentDecl); } + bool isGlobalDecl(Decl* decl) + { + if (!decl) + return false; + auto parentDecl = decl->parentDecl; + if (auto genericDecl = as<GenericDecl>(parentDecl)) + parentDecl = genericDecl->parentDecl; + return as<NamespaceDeclBase>(parentDecl) != nullptr; + } + /// Is `decl` a global shader parameter declaration? bool isGlobalShaderParameter(VarDeclBase* decl) { @@ -1920,37 +1930,21 @@ namespace Slang if(!requiredResultType->equals(satisfyingResultType)) return false; - witnessTable->add( - requiredMemberDeclRef.getDecl(), - RequirementWitness(satisfyingMemberDeclRef)); - if (hasForwardDerivative || hasBackwardDerivative) { - int fwdReqFound = 0; - int bwdReqFound = 0; - for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>()) + auto parentInterfaceDecl = as<InterfaceDecl>(getParentDecl(requiredMemberDeclRef.getDecl())); + if (parentInterfaceDecl) { - if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) - { - ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(fwdReq, RequirementWitness(val)); - fwdReqFound++; - } - else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) - { - BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(bwdReq, RequirementWitness(val)); - bwdReqFound++; - } + auto idiffType = DeclRefType::create(m_astBuilder, m_astBuilder->getDifferentiableInterface()); + bool noDiffThisSatisfying = !isDeclaredSubtype(witnessTable->witnessedType, idiffType); + bool noDiffThisRequirement = (requiredMemberDeclRef.getDecl()->findModifier<NoDiffThisAttribute>() != nullptr); + if (noDiffThisRequirement != noDiffThisSatisfying) + return false; } - - SLANG_RELEASE_ASSERT( - fwdReqFound == (hasForwardDerivative ? 1 : 0) && - bwdReqFound == (hasBackwardDerivative ? 1 : 0)); } + _addMethodWitness(witnessTable, requiredMemberDeclRef, satisfyingMemberDeclRef); + return true; } @@ -2543,7 +2537,10 @@ namespace Slang // mangled name! // synFuncDecl->nameAndLoc = requiredMemberDeclRef.getDecl()->nameAndLoc; - + if (synFuncDecl->nameAndLoc.name) + { + synFuncDecl->nameAndLoc.name = getSession()->getNameObj("$__syn_" + synFuncDecl->nameAndLoc.name->text); + } // The result type of our synthesized method will be the expected // result type from the interface requirement. // @@ -2592,6 +2589,13 @@ namespace Slang synArg->declRef = makeDeclRef(synParamDecl); synArg->type = paramType; synArgs.add(synArg); + + if (paramDeclRef.getDecl()->findModifier<NoDiffModifier>()) + { + auto noDiffModifier = m_astBuilder->create<NoDiffModifier>(); + noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); + addModifier(synParamDecl, noDiffModifier); + } } @@ -2625,13 +2629,52 @@ namespace Slang synThis->type.isLeftValue = true; auto synMutatingAttr = m_astBuilder->create<MutatingAttribute>(); - synFuncDecl->modifiers.first = synMutatingAttr; + addModifier(synFuncDecl, synMutatingAttr); + } + + if (requiredMemberDeclRef.getDecl()->hasModifier<NoDiffThisAttribute>()) + { + auto noDiffThisAttr = m_astBuilder->create<NoDiffThisAttribute>(); + addModifier(synFuncDecl, noDiffThisAttr); + } + if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) + { + auto attr = m_astBuilder->create<ForwardDifferentiableAttribute>(); + addModifier(synFuncDecl, attr); + } + if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) + { + auto attr = m_astBuilder->create<BackwardDifferentiableAttribute>(); + addModifier(synFuncDecl, attr); } } return synFuncDecl; } + void SemanticsVisitor::_addMethodWitness( + WitnessTable* witnessTable, + DeclRef<CallableDecl> requiredMemberDeclRef, + DeclRef<CallableDecl> satisfyingMemberDeclRef) + { + for (auto reqRefDecl : requiredMemberDeclRef.getDecl()->getMembersOfType<DerivativeRequirementReferenceDecl>()) + { + if (auto fwdReq = as<ForwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(fwdReq, RequirementWitness(val)); + } + else if (auto bwdReq = as<BackwardDerivativeRequirementDecl>(reqRefDecl->referencedDecl)) + { + BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(bwdReq, RequirementWitness(val)); + } + } + witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); + } + bool SemanticsVisitor::trySynthesizeMethodRequirementWitness( ConformanceCheckingContext* context, LookupResult const& lookupResult, @@ -2806,8 +2849,7 @@ namespace Slang // difference between our synthetic method and a hand-written // one with the same behavior. // - witnessTable->add(requiredMemberDeclRef, - RequirementWitness(makeDeclRef(synFuncDecl))); + _addMethodWitness(witnessTable, requiredMemberDeclRef, makeDeclRef(synFuncDecl)); return true; } @@ -5593,6 +5635,7 @@ namespace Slang if (auto interfaceDecl = findParentInterfaceDecl(decl)) { + bool isDiffFunc = false; if (decl->hasModifier<ForwardDifferentiableAttribute>()) { auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); @@ -5607,6 +5650,7 @@ namespace Slang reqRef->referencedDecl = reqDecl; reqRef->parentDecl = decl; decl->members.add(reqRef); + isDiffFunc = true; } if (decl->hasModifier<BackwardDifferentiableAttribute>()) { @@ -5622,6 +5666,36 @@ namespace Slang reqRef->referencedDecl = reqDecl; reqRef->parentDecl = decl; decl->members.add(reqRef); + isDiffFunc = true; + } + if (isDiffFunc) + { + auto interfaceDeclRef = createDefaultSubstitutionsIfNeeded(m_astBuilder, this, makeDeclRef(interfaceDecl)); + auto interfaceType = DeclRefType::create(m_astBuilder, interfaceDeclRef); + bool noDiffThisRequirement = !isTypeDifferentiable(interfaceType); + if (noDiffThisRequirement) + { + auto noDiffThisModifier = m_astBuilder->create<NoDiffThisAttribute>(); + addModifier(decl, noDiffThisModifier); + } + } + } + if (decl->findModifier<DifferentiableAttribute>()) + { + // Add `no_diff` modifiers to parameters. + // This is necessary to preserve no-diff-ness for generic function before and after + // specialization. + for (auto paramDecl : decl->getParameters()) + { + if (paramDecl->type.type && !isTypeDifferentiable(paramDecl->type.type)) + { + if (!paramDecl->hasModifier<NoDiffModifier>()) + { + auto noDiffModifier = m_astBuilder->create<NoDiffModifier>(); + noDiffModifier->keywordName = getSession()->getNameObj("no_diff"); + addModifier(paramDecl, noDiffModifier); + } + } } } } diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 7297ca282..336682bf4 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -503,7 +503,11 @@ namespace Slang auto toBeSynthesized = m_astBuilder->create<ToBeSynthesizedModifier>(); addModifier(synthesizedDecl, toBeSynthesized); - return ConstructDeclRefExpr(makeDeclRef(synthesizedDecl), nullptr, originalExpr->loc, originalExpr); + return ConstructDeclRefExpr( + makeDeclRef(synthesizedDecl), + nullptr, + originalExpr ? originalExpr->loc : SourceLoc(), + originalExpr); } Expr* SemanticsVisitor::ConstructLookupResultExpr( @@ -1927,6 +1931,10 @@ namespace Slang { getSink()->diagnose(forwardDiff, Diagnostics::functionNotMarkedAsDifferentiable, funcDecl, "backward"); } + if (!isEffectivelyStatic(funcDecl) && !isGlobalDecl(funcDecl)) + { + getSink()->diagnose(forwardDiff, Diagnostics::nonStaticMemberFunctionNotAllowedAsDiffOperand, funcDecl); + } } } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 1c2f698bd..fb47a38c1 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -16,6 +16,8 @@ namespace Slang bool isEffectivelyStatic( Decl* decl); + bool isGlobalDecl(Decl* decl); + Type* checkProperType( Linkage* linkage, TypeExp typeExp, @@ -1026,6 +1028,11 @@ namespace Slang List<Expr*>& synArgs, ThisExpr*& synThis); + void _addMethodWitness( + WitnessTable* witnessTable, + DeclRef<CallableDecl> requirement, + DeclRef<CallableDecl> method); + /// Attempt to synthesize a method that can satisfy `requiredMemberDeclRef` using `lookupResult`. /// /// On success, installs the syntethesized method in `witnessTable` and returns `true`. @@ -1431,6 +1438,8 @@ namespace Slang bool isInterfaceType(Type* type); + bool isTypeDifferentiable(Type* type); + /// Check whether `subType` is a sub-type of `superTypeDeclRef`, /// and return a witness to the sub-type relationship if it holds /// (return null otherwise). diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index d293626ae..fc92241e1 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -305,6 +305,7 @@ DIAGNOSTIC(30095, Error, errorTypeOfCalleeIncompatibleWithCaller, "the error typ DIAGNOSTIC(30096, Error, differentialTypeShouldServeAsItsOwnDifferentialType, "type '$0' is used as a `Differential` type, therefore it must serve as its own `Differential` type.") DIAGNOSTIC(30097, Error, functionNotMarkedAsDifferentiable, "function '$0' is not marked as $1-differentiable.") +DIAGNOSTIC(30098, Error, nonStaticMemberFunctionNotAllowedAsDiffOperand, "non-static function reference '$0' is not allowed here.") DIAGNOSTIC(-1, Note, noteSeeUseOfDifferentialType, "see use of '$0' as Differential of '$1'.") diff --git a/source/slang/slang-emit.cpp b/source/slang/slang-emit.cpp index 508402736..2476f79e5 100644 --- a/source/slang/slang-emit.cpp +++ b/source/slang/slang-emit.cpp @@ -382,7 +382,9 @@ Result linkAndOptimizeIR( if (!changed) break; } - + + finalizeAutoDiffPass(irModule); + lowerReinterpret(targetRequest, irModule, sink); validateIRModuleIfEnabled(codeGenContext, irModule); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index d45dd0c10..c94342736 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -11,6 +11,12 @@ namespace Slang { +static IRInst* _unwrapAttributedType(IRInst* type) +{ + while (auto attrType = as<IRAttributedType>(type)) + type = attrType->getBaseType(); + return type; +} DiagnosticSink* ForwardDerivativeTranscriber::getSink() { @@ -183,8 +189,12 @@ IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType) { IRBuilder builder(sharedBuilder); - builder.setInsertInto(primalType->parent); - auto witness = as<IRWitnessTable>( + if (!primalType->next) + builder.setInsertInto(primalType->parent); + else + builder.setInsertBefore(primalType->next); + + IRInst* witness = as<IRWitnessTable>( differentiableTypeConformanceContext.lookUpConformanceForType((IRType*)primalType)); if (!witness) @@ -193,6 +203,10 @@ IRType* ForwardDerivativeTranscriber::getOrCreateDiffPairType(IRInst* primalType { witness = getDifferentialPairWitness(primalPairType); } + else if (auto extractExistential = as<IRExtractExistentialType>(primalType)) + { + differentiateExtractExistentialType(&builder, extractExistential, witness); + } } return builder.getDifferentialPairType( @@ -271,6 +285,12 @@ IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder, else return nullptr; + case kIROp_ExtractExistentialType: + { + IRInst* wt = nullptr; + return differentiateExtractExistentialType(builder, as<IRExtractExistentialType>(primalType), wt); + } + case kIROp_TupleType: { auto tupleType = as<IRTupleType>(primalType); @@ -288,6 +308,75 @@ IRType* ForwardDerivativeTranscriber::_differentiateTypeImpl(IRBuilder* builder, } } + // Given an interface type, return the lookup path from a witness table of `type` to a witness table of `IDifferentiable`. +bool _findDifferentiableInterfaceLookupPathImpl( + HashSet<IRInst*>& processedTypes, + IRInterfaceType* idiffType, + IRInterfaceType* type, + List<IRInterfaceRequirementEntry*>& currentPath) +{ + if (processedTypes.Contains(type)) + return false; + processedTypes.Add(type); + + List<IRInterfaceRequirementEntry*> lookupKeyPath; + for (UInt i = 0; i < type->getOperandCount(); i++) + { + auto entry = as<IRInterfaceRequirementEntry>(type->getOperand(i)); + if (!entry) continue; + if (auto wt = as<IRWitnessTableTypeBase>(entry->getRequirementVal())) + { + currentPath.add(entry); + if (wt->getConformanceType() == idiffType) + { + return true; + } + else if (auto subInterfaceType = as<IRInterfaceType>(wt->getConformanceType())) + { + if (_findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, subInterfaceType, currentPath)) + return true; + } + currentPath.removeLast(); + } + } + return false; +} + +List<IRInterfaceRequirementEntry*> _findDifferentiableInterfaceLookupPath( + IRInterfaceType* idiffType, + IRInterfaceType* type) +{ + List<IRInterfaceRequirementEntry*> currentPath; + HashSet<IRInst*> processedTypes; + _findDifferentiableInterfaceLookupPathImpl(processedTypes, idiffType, type, currentPath); + return currentPath; +} + +IRType* ForwardDerivativeTranscriber::differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable) +{ + witnessTable = nullptr; + + // Search for IDifferentiable conformance. + auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(origType->getOperand(0)->getDataType())); + if (!interfaceType) + return nullptr; + List<IRInterfaceRequirementEntry*> lookupKeyPath = _findDifferentiableInterfaceLookupPath( + autoDiffSharedContext->differentiableInterfaceType, interfaceType); + + if (lookupKeyPath.getCount()) + { + // `interfaceType` does conform to `IDifferentiable`. + witnessTable = builder->emitExtractExistentialWitnessTable(origType->getOperand(0)); + for (auto node : lookupKeyPath) + { + witnessTable = builder->emitLookupInterfaceMethodInst((IRType*)node->getRequirementVal(), witnessTable, node->getRequirementKey()); + } + auto diffType = builder->emitLookupInterfaceMethodInst(builder->getTypeType(), witnessTable, autoDiffSharedContext->differentialAssocTypeStructKey); + return (IRType*)diffType; + } + return nullptr; +} + IRType* ForwardDerivativeTranscriber::tryGetDiffPairType(IRBuilder* builder, IRType* primalType) { // If this is a PtrType (out, inout, etc..), then create diff pair from @@ -699,6 +788,10 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall return InstPair(primalCall, nullptr); } + auto calleeType = as<IRFuncType>(diffCallee->getDataType()); + SLANG_ASSERT(calleeType); + SLANG_RELEASE_ASSERT(calleeType->getParamCount() == origCall->getArgCount()); + List<IRInst*> args; // Go over the parameter list and create pairs for each input (if required) for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) @@ -707,7 +800,15 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall auto primalArg = findOrTranscribePrimalInst(builder, origArg); SLANG_ASSERT(primalArg); - auto primalType = primalArg->getDataType(); + auto primalType = primalArg->getDataType(); + auto paramType = calleeType->getParamType(ii); + if (!isNoDiffType(paramType)) + { + if (isNoDiffType(primalType)) + { + while (auto attrType = as<IRAttributedType>(primalType)) + primalType = attrType->getBaseType(); + } if (auto pairType = tryGetDiffPairType(builder, primalType)) { auto diffArg = findOrTranscribeDiffInst(builder, origArg); @@ -718,16 +819,16 @@ InstPair ForwardDerivativeTranscriber::transcribeCall(IRBuilder* builder, IRCall SLANG_RELEASE_ASSERT(diffArg); auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); args.add(diffPair); - } - else - { - // Add original/primal argument. - args.add(primalArg); + continue; } } + // Argument is not differentiable. + // Add original/primal argument. + args.add(primalArg); + } - IRType* diffReturnType = nullptr; - diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); + IRType* diffReturnType = nullptr; + diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); if (!diffReturnType) { @@ -942,37 +1043,37 @@ InstPair ForwardDerivativeTranscriber::transcribeSpecialize(IRBuilder* builder, // Make sure this isn't itself a specialize . SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); - return InstPair(primalSpecialize, jvpFunc); - } - else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>()) - { - diffBase = derivativeDecoration->getForwardDerivativeFunc(); - List<IRInst*> args; - for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) - { - args.add(primalSpecialize->getArg(i)); - } - auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); - return InstPair(primalSpecialize, diffSpecialize); - } - else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>()) + return InstPair(primalSpecialize, jvpFunc); + } + else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>()) + { + diffBase = derivativeDecoration->getForwardDerivativeFunc(); + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { - List<IRInst*> args; - for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) - { - args.add(primalSpecialize->getArg(i)); - } - diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); - auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); - return InstPair(primalSpecialize, diffSpecialize); + args.add(primalSpecialize->getArg(i)); } - else + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } + else if (auto diffDecor = genericInnerVal->findDecoration<IRForwardDifferentiableDecoration>()) + { + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) { - return InstPair(primalSpecialize, nullptr); + args.add(primalSpecialize->getArg(i)); } + diffBase = findOrTranscribeDiffInst(builder, origSpecialize->getBase()); + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } + else + { + return InstPair(primalSpecialize, nullptr); } +} InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst) { @@ -981,7 +1082,7 @@ InstPair ForwardDerivativeTranscriber::transcribeLookupInterfaceMethod(IRBuilder auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType()); auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey); - auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType()); + auto interfaceType = as<IRInterfaceType>(_unwrapAttributedType(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType())); if (!interfaceType) { return InstPair(primal, nullptr); @@ -1031,7 +1132,17 @@ IRInst* ForwardDerivativeTranscriber::getDifferentialZeroOfType(IRBuilder* build // Since primalType has a corresponding differential type, we can lookup the // definition for zero(). auto zeroMethod = differentiableTypeConformanceContext.getZeroMethodForType(builder, primalType); - SLANG_ASSERT(zeroMethod); + if (!zeroMethod) + { + // if the differential type itself comes from a witness lookup, we can just lookup the + // zero method from the same witness table. + if (auto lookupInterface = as<IRLookupWitnessMethod>(diffType)) + { + auto wt = lookupInterface->getWitnessTable(); + zeroMethod = builder->emitLookupInterfaceMethodInst(builder->getFuncType(List<IRType*>(), diffType), wt, autoDiffSharedContext->zeroMethodStructKey); + } + } + SLANG_RELEASE_ASSERT(zeroMethod); auto emptyArgList = List<IRInst*>(); return builder->emitCallInst((IRType*)diffType, zeroMethod, emptyArgList); diff --git a/source/slang/slang-ir-autodiff-fwd.h b/source/slang/slang-ir-autodiff-fwd.h index ab5d753d6..678677625 100644 --- a/source/slang/slang-ir-autodiff-fwd.h +++ b/source/slang/slang-ir-autodiff-fwd.h @@ -106,6 +106,8 @@ struct ForwardDerivativeTranscriber IRType* _differentiateTypeImpl(IRBuilder* builder, IRType* origType); + IRType* differentiateExtractExistentialType(IRBuilder* builder, IRExtractExistentialType* origType, IRInst*& witnessTable); + IRType* tryGetDiffPairType(IRBuilder* builder, IRType* primalType); InstPair transcribeParam(IRBuilder* builder, IRParam* origParam); diff --git a/source/slang/slang-ir-autodiff-pairs.cpp b/source/slang/slang-ir-autodiff-pairs.cpp index b9b4a8b66..dc72ed44a 100644 --- a/source/slang/slang-ir-autodiff-pairs.cpp +++ b/source/slang/slang-ir-autodiff-pairs.cpp @@ -1,4 +1,5 @@ #include "slang-ir-autodiff-pairs.h" +#include "slang-ir-hoist-local-types.h" namespace Slang { @@ -13,25 +14,22 @@ struct DiffPairLoweringPass : InstPassBase pairBuilder = &pairBuilderStorage; } - IRInst* lowerPairType(IRBuilder* builder, IRType* pairType, bool* isTrivial = nullptr) + IRInst* lowerPairType(IRBuilder* builder, IRType* pairType) { builder->setInsertBefore(pairType); - auto loweredPairTypeInfo = pairBuilder->lowerDiffPairType( + auto loweredPairType = pairBuilder->lowerDiffPairType( builder, pairType); - if (isTrivial) - *isTrivial = loweredPairTypeInfo.isTrivial; - return loweredPairTypeInfo.loweredType; + return loweredPairType; } IRInst* lowerMakePair(IRBuilder* builder, IRInst* inst) { - if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { bool isTrivial = false; auto pairType = as<IRDifferentialPairType>(makePairInst->getDataType()); - if (auto loweredPairType = lowerPairType(builder, pairType, &isTrivial)) + if (auto loweredPairType = lowerPairType(builder, pairType)) { builder->setInsertBefore(makePairInst); IRInst* result = nullptr; @@ -63,7 +61,7 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType, nullptr)) + if (lowerPairType(builder, pairType)) { builder->setInsertBefore(getDiffInst); IRInst* diffFieldExtract = nullptr; @@ -81,7 +79,7 @@ struct DiffPairLoweringPass : InstPassBase pairType = pairPtrType->getValueType(); } - if (lowerPairType(builder, pairType, nullptr)) + if (lowerPairType(builder, pairType)) { builder->setInsertBefore(getPrimalInst); @@ -99,27 +97,9 @@ struct DiffPairLoweringPass : InstPassBase bool processInstWithChildren(IRBuilder* builder, IRInst* instWithChildren) { bool modified = false; + // Hoist all pair types to global scope when possible. - auto moduleInst = module->getModuleInst(); - processInstsOfType<IRDifferentialPairType>(kIROp_DifferentialPairType, [&](IRInst* originalPairType) - { - if (originalPairType->parent != moduleInst) - { - originalPairType->removeFromParent(); - ShortList<IRInst*> operands; - for (UInt i = 0; i < originalPairType->getOperandCount(); i++) - { - operands.add(originalPairType->getOperand(i)); - } - auto newPairType = builder->findOrEmitHoistableInst( - originalPairType->getFullType(), - originalPairType->getOp(), - originalPairType->getOperandCount(), - operands.getArrayView().getBuffer()); - originalPairType->replaceUsesWith(newPairType); - originalPairType->removeAndDeallocate(); - } - }); + hoistLocalTypes(module); autodiffContext->sharedBuilder->deduplicateAndRebuildGlobalNumberingMap(); diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 5b5832073..86429f9ba 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -27,6 +27,20 @@ IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementK return nullptr; } +bool isNoDiffType(IRType* paramType) +{ + while (auto ptrType = as<IRPtrTypeBase>(paramType)) + paramType = ptrType->getValueType(); + while (auto attrType = as<IRAttributedType>(paramType)) + { + if (attrType->findAttr<IRNoDiffAttr>()) + { + return true; + } + } + return false; +} + IRStructField* DifferentialPairTypeBuilder::findField(IRInst* type, IRStructKey* key) { if (auto irStructType = as<IRStructType>(type)) @@ -80,19 +94,14 @@ IRInst* DifferentialPairTypeBuilder::emitFieldAccessor(IRBuilder* builder, IRIns IRInst* pairType = nullptr; if (auto basePtrType = as<IRPtrTypeBase>(baseInst->getDataType())) { - auto baseTypeInfo = lowerDiffPairType(builder, basePtrType->getValueType()); + auto loweredType = lowerDiffPairType(builder, basePtrType->getValueType()); - // TODO(sai): Not sure at the moment how to handle diff-bottom pointer types, - // especially since we probably don't need diff bottom anymore. - // - SLANG_ASSERT(!baseTypeInfo.isTrivial); - - pairType = builder->getPtrType(kIROp_PtrType, (IRType*)baseTypeInfo.loweredType); + pairType = builder->getPtrType(kIROp_PtrType, (IRType*)loweredType); } else { - auto baseTypeInfo = lowerDiffPairType(builder, baseInst->getDataType()); - pairType = baseTypeInfo.loweredType; + auto loweredType = lowerDiffPairType(builder, baseInst->getDataType()); + pairType = loweredType; } if (auto basePairStructType = as<IRStructType>(pairType)) @@ -240,33 +249,29 @@ IRInst* DifferentialPairTypeBuilder::getDiffTypeWitnessFromPairType(IRBuilder* b return _lookupWitness(builder, witnessTable, sharedContext->differentialAssocTypeWitnessStructKey); } -DifferentialPairTypeBuilder::LoweredPairTypeInfo DifferentialPairTypeBuilder::lowerDiffPairType( +IRInst* DifferentialPairTypeBuilder::lowerDiffPairType( IRBuilder* builder, IRType* originalPairType) { - LoweredPairTypeInfo result = {}; - + IRInst* result = nullptr; if (pairTypeCache.TryGetValue(originalPairType, result)) return result; auto pairType = as<IRDifferentialPairType>(originalPairType); if (!pairType) { - result.isTrivial = true; - result.loweredType = originalPairType; + result = originalPairType; return result; } auto primalType = pairType->getValueType(); if (as<IRParam>(primalType)) { - result.isTrivial = false; - result.loweredType = nullptr; + result = nullptr; return result; } auto diffType = getDiffTypeFromPairType(builder, pairType); if (!diffType) return result; - result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); - result.isTrivial = false; + result = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); pairTypeCache.Add(originalPairType, result); return result; @@ -469,6 +474,22 @@ bool processAutodiffCalls( // Process reverse derivative calls. modified |= processReverseDerivativeCalls(&autodiffContext, sink); + return modified; +} + +bool finalizeAutoDiffPass(IRModule* module) +{ + bool modified = false; + + // Create shared context for all auto-diff related passes + AutoDiffSharedContext autodiffContext(module->getModuleInst()); + + SharedIRBuilder sharedBuilder; + sharedBuilder.init(module); + sharedBuilder.deduplicateAndRebuildGlobalNumberingMap(); + + autodiffContext.sharedBuilder = &sharedBuilder; + // Replaces IRDifferentialPairType with an auto-generated struct, // IRDifferentialPairGetDifferential with 'differential' field access, // IRDifferentialPairGetPrimal with 'primal' field access, and @@ -481,7 +502,7 @@ bool processAutodiffCalls( // Remove auto-diff related decorations. stripAutoDiffDecorations(module); - return modified; + return false; } diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index e470044a4..25cbe16f4 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -147,12 +147,6 @@ struct DifferentiableTypeConformanceContext struct DifferentialPairTypeBuilder { - struct LoweredPairTypeInfo - { - IRInst* loweredType; - bool isTrivial; - }; - DifferentialPairTypeBuilder() = default; DifferentialPairTypeBuilder(AutoDiffSharedContext* sharedContext) : sharedContext(sharedContext) {} @@ -177,10 +171,16 @@ struct DifferentialPairTypeBuilder IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairType* type); - LoweredPairTypeInfo lowerDiffPairType(IRBuilder* builder, IRType* originalPairType); + IRInst* lowerDiffPairType(IRBuilder* builder, IRType* originalPairType); + struct PairStructKey + { + IRInst* originalType; + IRInst* diffType; + }; - Dictionary<IRInst*, LoweredPairTypeInfo> pairTypeCache; + // Cache from `IRDifferentialPairType` to materialized struct type. + Dictionary<IRInst*, IRInst*> pairTypeCache; IRStructKey* globalPrimalKey = nullptr; @@ -197,6 +197,8 @@ void stripAutoDiffDecorations(IRModule* module); IRInst* _lookupWitness(IRBuilder* builder, IRInst* witness, IRInst* requirementKey); +bool isNoDiffType(IRType* paramType); + struct IRAutodiffPassOptions { // Nothing for now... @@ -207,4 +209,6 @@ bool processAutodiffCalls( DiagnosticSink* sink, IRAutodiffPassOptions const& options = IRAutodiffPassOptions()); -};
\ No newline at end of file +bool finalizeAutoDiffPass(IRModule* module); + +}; diff --git a/source/slang/slang-ir-hoist-local-types.cpp b/source/slang/slang-ir-hoist-local-types.cpp index 756a25c49..cf091f701 100644 --- a/source/slang/slang-ir-hoist-local-types.cpp +++ b/source/slang/slang-ir-hoist-local-types.cpp @@ -8,7 +8,6 @@ namespace Slang struct HoistLocalTypesContext { IRModule* module; - DiagnosticSink* sink; SharedIRBuilder sharedBuilderStorage; @@ -98,11 +97,10 @@ struct HoistLocalTypesContext } }; -void hoistLocalTypes(IRModule* module, DiagnosticSink* sink) +void hoistLocalTypes(IRModule* module) { HoistLocalTypesContext context; context.module = module; - context.sink = sink; context.processModule(); } diff --git a/source/slang/slang-ir-hoist-local-types.h b/source/slang/slang-ir-hoist-local-types.h index 6b742746f..55e62ce57 100644 --- a/source/slang/slang-ir-hoist-local-types.h +++ b/source/slang/slang-ir-hoist-local-types.h @@ -13,6 +13,6 @@ class DiagnosticSink; /// can be hoisted to global scope. This pass examines all local type defintions // and try to hoist them to global scope if the definition is no longer dependent on // the local context. -void hoistLocalTypes(IRModule* module, DiagnosticSink* sink); +void hoistLocalTypes(IRModule* module); } // namespace Slang diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 28639ae53..f836824f7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -2884,6 +2884,11 @@ void collectParameterLists( auto thisType = getThisParamTypeForContainer(context, parentDeclRef); if(thisType) { + if (declRef.getDecl()->findModifier<NoDiffThisAttribute>()) + { + auto noDiffAttr = context->astBuilder->getNoDiffModifierVal(); + thisType = context->astBuilder->getModifiedType(thisType, 1, &noDiffAttr); + } addThisParameter(innerThisParamDirection, thisType, ioParameterLists); } } diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index a79c48227..4e5db17c0 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -1368,6 +1368,14 @@ Module* getModule(Decl* decl) return moduleDecl->module; } +Decl* getParentDecl(Decl* decl) +{ + decl = decl->parentDecl; + while (as<GenericDecl>(decl)) + decl = decl->parentDecl; + return decl; +} + static const ImageFormatInfo kImageFormatInfos[] = { #define SLANG_IMAGE_FORMAT_INFO(TYPE, COUNT, SIZE) SLANG_SCALAR_TYPE_##TYPE, uint8_t(COUNT), uint8_t(SIZE) diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 441dcb8e7..e36ee944c 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -329,10 +329,11 @@ namespace Slang /// Get the module dclaration that a declaration is associated with, if any. ModuleDecl* getModuleDecl(Decl* decl); - /// Get the module that a declaration is associated with, if any. + /// Get the module that a declaration is associated with, if any. Module* getModule(Decl* decl); - + /// Get the parent decl, skipping any generic decls in between. + Decl* getParentDecl(Decl* decl); } // namespace Slang |
