From 56a84a06488afb817f79fbd99e8b470bd587ccd1 Mon Sep 17 00:00:00 2001 From: Yong He Date: Thu, 23 Mar 2023 22:42:59 -0700 Subject: Fix various autodiff crashes related to interface usage. (#2730) * Fix crash. * Fix `[ForwradDerivative]` on member functions. * Update comments. * Fix crash when [BackwardDerivative] is provided but not [ForwardDerivative]. * Allow calling dynamic dispatched generic method from differentiable func. * Fix. --------- Co-authored-by: Yong He --- source/slang/slang-check-decl.cpp | 4 +++ source/slang/slang-diagnostic-defs.h | 2 ++ source/slang/slang-ir-addr-inst-elimination.cpp | 21 +++--------- source/slang/slang-ir-autodiff-fwd.cpp | 40 +++++++++++++++------- .../slang/slang-ir-autodiff-transcriber-base.cpp | 8 ++++- source/slang/slang-ir-autodiff.cpp | 1 + source/slang/slang-lower-to-ir.cpp | 11 ++++-- 7 files changed, 54 insertions(+), 33 deletions(-) (limited to 'source') diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 6083ce9c0..6a32f59d3 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -5638,6 +5638,10 @@ namespace Slang bool isDiffFunc = false; if (decl->hasModifier() || decl->hasModifier()) { + if (GetOuterGeneric(decl)) + { + getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported); + } auto reqDecl = m_astBuilder->create(); cloneModifiers(reqDecl, decl); auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 2401b6e58..e3e9cfc44 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -352,6 +352,8 @@ DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative att DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[$1]'.") DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") +DIAGNOSTIC(31148, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.") + // Enums DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") diff --git a/source/slang/slang-ir-addr-inst-elimination.cpp b/source/slang/slang-ir-addr-inst-elimination.cpp index 6715f2c6a..16bd67f66 100644 --- a/source/slang/slang-ir-addr-inst-elimination.cpp +++ b/source/slang/slang-ir-addr-inst-elimination.cpp @@ -99,22 +99,11 @@ struct AddressInstEliminationContext IRBuilder builder(module); builder.setInsertBefore(call); auto tempVar = builder.emitVar(cast(addr->getFullType())->getValueType()); - auto callee = getResolvedInstForDecorations(call->getCallee()); - auto funcType = as(callee->getFullType()); - SLANG_RELEASE_ASSERT(funcType); - UInt paramIndex = (UInt)(use - call->getOperands() - 1); - SLANG_RELEASE_ASSERT(call->getArg(paramIndex) == addr); - if (!as(funcType->getParamType(paramIndex))) - { - builder.emitStore(tempVar, getValue(builder, addr)); - } - else - { - builder.emitStore( - tempVar, - builder.emitDefaultConstruct( - as(tempVar->getDataType())->getValueType())); - } + + // Store the initial value of the mutable argument into temp var. + // If this is an `out` var, the initial value will be undefined, + // which will get cleaned up later into a `defaultConstruct`. + builder.emitStore(tempVar, getValue(builder, addr)); builder.setInsertAfter(call); storeValue(builder, addr, builder.emitLoad(tempVar)); use->set(tempVar); diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 3f31f1463..869f8920c 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -510,7 +510,7 @@ IRInst* tryFindPrimalSubstitute(IRBuilder* builder, IRInst* callee) { auto innerGen = as(specialize->getBase()); if (!innerGen) - return nullptr; + return callee; auto innerFunc = findGenericReturnVal(innerGen); if (auto decor = innerFunc->findDecoration()) { @@ -553,7 +553,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig return InstPair(nullptr, nullptr); } - auto primalCallee = lookupPrimalInst(builder, origCallee, origCallee); + auto primalCallee = findOrTranscribePrimalInst(builder, origCallee); auto substPrimalCallee = tryFindPrimalSubstitute(builder, primalCallee); IRInst* diffCallee = nullptr; @@ -563,7 +563,7 @@ InstPair ForwardDiffTranscriber::transcribeCall(IRBuilder* builder, IRCall* orig } else { - instMapD.TryGetValue(substPrimalCallee, diffCallee); + diffCallee = findOrTranscribeDiffInst(builder, origCallee); primalCallee = substPrimalCallee; } @@ -904,17 +904,32 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec IRInst* diffBase = nullptr; if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase)) { - List args; - for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + if (diffBase) { - args.add(primalSpecialize->getArg(i)); + List args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); + } + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } + else + { + return InstPair(primalSpecialize, nullptr); } - auto diffSpecialize = builder->emitSpecializeInst( - builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); - return InstPair(primalSpecialize, diffSpecialize); } auto genericInnerVal = findInnerMostGenericReturnVal(as(origSpecialize->getBase())); + + // Right now we don't support transcribing a differentiable callee that is a specialize of a interface lookup + // (calling differentiable generic interface method). To support it, we need to recursively transcribe the + // specialization base here. + + if (!genericInnerVal) + return InstPair(primalSpecialize, nullptr); + // Look for an IRForwardDerivativeDecoration on the specialize inst. // (Normally, this would be on the inner IRFunc, but in this case only the JVP func // can be specialized, so we put a decoration on the IRSpecialize) @@ -963,10 +978,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } - else - { - return InstPair(primalSpecialize, nullptr); - } + return InstPair(primalSpecialize, nullptr); } InstPair ForwardDiffTranscriber::transcribeFieldExtract(IRBuilder* builder, IRInst* originalInst) @@ -1433,6 +1445,8 @@ IRFunc* ForwardDiffTranscriber::transcribeFuncHeaderImpl(IRBuilder* inBuilder, I IRFunc* primalFunc = origFunc; + maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, origFunc); + differentiableTypeConformanceContext.setFunc(origFunc); primalFunc = origFunc; diff --git a/source/slang/slang-ir-autodiff-transcriber-base.cpp b/source/slang/slang-ir-autodiff-transcriber-base.cpp index 552ac762c..9cbea7873 100644 --- a/source/slang/slang-ir-autodiff-transcriber-base.cpp +++ b/source/slang/slang-ir-autodiff-transcriber-base.cpp @@ -594,7 +594,11 @@ void AutoDiffTranscriberBase::maybeMigrateDifferentiableDictionaryFromDerivative } else { - cloneDecoration(udfDecor, origFunc); + auto udfDictDecor = derivative->findDecoration< IRDifferentiableTypeDictionaryDecoration>(); + if (udfDictDecor) + { + cloneDecoration(udfDictDecor, origFunc); + } } } @@ -977,6 +981,8 @@ InstPair AutoDiffTranscriberBase::transcribeGeneric(IRBuilder* inBuilder, IRGene if (auto innerFunc = as(innerVal)) { maybeMigrateDifferentiableDictionaryFromDerivativeFunc(inBuilder, innerFunc); + if (!innerFunc->findDecoration()) + return InstPair(origGeneric, nullptr); differentiableTypeConformanceContext.setFunc(innerFunc); } else if (auto funcType = as(innerVal)) diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index f173aaa8b..1909f860c 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -368,6 +368,7 @@ void DifferentiableTypeConformanceContext::setFunc(IRGlobalValueWithCode* func) { parentFunc = func; + auto decor = func->findDecoration(); SLANG_RELEASE_ASSERT(decor); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index f84f17886..9c27beb58 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8484,10 +8484,15 @@ struct DeclLoweringVisitor : DeclVisitor funcExpr = udAttr->funcExpr; else if (auto primalAttr = as(modifier)) funcExpr = primalAttr->funcExpr; + DeclRefExpr* declRefExpr = as(funcExpr); + auto funcType = lowerType(subContext, funcExpr->type); + auto loweredVal = emitDeclRef( + subContext, + declRefExpr->declRef, + funcType); + + SLANG_RELEASE_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - auto loweredVal = lowerRValueExpr(subContext, funcExpr); - - SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); IRInst* derivativeFunc = loweredVal.val; if (as(modifier)) -- cgit v1.2.3