From 85c1569308793cc2408088e539a3ed1da5f9d235 Mon Sep 17 00:00:00 2001 From: Yong He Date: Fri, 24 Feb 2023 14:33:32 -0800 Subject: Support dynamic dispatch a backward differentiable function. (#2678) Co-authored-by: Yong He --- source/slang/slang-ast-decl.h | 15 ------- source/slang/slang-check-decl.cpp | 69 +---------------------------- source/slang/slang-ir-autodiff-transpose.h | 56 ++++++++++++++++++------ source/slang/slang-ir-autodiff-unzip.h | 70 +++++++++++++++++++++--------- source/slang/slang-lower-to-ir.cpp | 14 ------ source/slang/slang-mangle.cpp | 6 --- 6 files changed, 93 insertions(+), 137 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 81a6e3f7d..ccbac0286 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -546,21 +546,6 @@ class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl SLANG_AST_CLASS(BackwardDerivativeRequirementDecl) }; -class BackwardDerivativePrimalRequirementDecl : public DerivativeRequirementDecl -{ - SLANG_AST_CLASS(BackwardDerivativePrimalRequirementDecl) -}; - -class BackwardDerivativePropagateRequirementDecl : public DerivativeRequirementDecl -{ - SLANG_AST_CLASS(BackwardDerivativePropagateRequirementDecl) -}; - -class BackwardDerivativeIntermediateTypeRequirementDecl : public DerivativeRequirementDecl -{ - SLANG_AST_CLASS(BackwardDerivativeIntermediateTypeRequirementDecl) -}; - bool isInterfaceRequirement(Decl* decl); InterfaceDecl* findParentInterfaceDecl(Decl* decl); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 142842e12..a1d5acfb0 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -2677,24 +2677,6 @@ namespace Slang val->func = satisfyingMemberDeclRef; witnessTable->add(bwdReq, RequirementWitness(val)); } - else if (auto primalReq = as(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(primalReq, RequirementWitness(val)); - } - else if (auto propReq = as(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(propReq, RequirementWitness(val)); - } - else if (auto itypeReq = as(reqRefDecl->referencedDecl)) - { - DifferentiateVal* val = m_astBuilder->create(); - val->func = satisfyingMemberDeclRef; - witnessTable->add(itypeReq, RequirementWitness(val)); - } } witnessTable->add(requiredMemberDeclRef, RequirementWitness(satisfyingMemberDeclRef)); } @@ -5920,7 +5902,7 @@ namespace Slang if (auto interfaceDecl = findParentInterfaceDecl(decl)) { bool isDiffFunc = false; - if (decl->hasModifier()) + if (decl->hasModifier() || decl->hasModifier()) { auto reqDecl = m_astBuilder->create(); cloneModifiers(reqDecl, decl); @@ -5954,55 +5936,6 @@ namespace Slang reqRef->parentDecl = decl; decl->members.add(reqRef); } - // Requirement for backward derivative intermediate type. - auto intermediateTypeReqDecl = m_astBuilder->create(); - auto intermediateType = m_astBuilder->getOrCreateDeclRefType( - intermediateTypeReqDecl, createDefaultSubstitutions(m_astBuilder, this, decl)); - { - cloneModifiers(intermediateTypeReqDecl, decl); - interfaceDecl->members.add(intermediateTypeReqDecl); - intermediateTypeReqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create(); - reqRef->referencedDecl = intermediateTypeReqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - // Requirement for backward derivative primal func. - { - auto reqDecl = m_astBuilder->create(); - cloneModifiers(reqDecl, decl); - FuncType* primalFuncType = m_astBuilder->create(); - primalFuncType->resultType = originalFuncType->resultType; - primalFuncType->paramTypes.addRange(originalFuncType->paramTypes); - auto outType = m_astBuilder->getOutType(intermediateType); - primalFuncType->paramTypes.add(outType); - setFuncTypeIntoRequirementDecl(reqDecl, primalFuncType); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - - auto reqRef = m_astBuilder->create(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - // Requirement for backward derivative propagate func. - { - auto reqDecl = m_astBuilder->create(); - cloneModifiers(reqDecl, decl); - interfaceDecl->members.add(reqDecl); - reqDecl->parentDecl = interfaceDecl; - FuncType* propagateFuncType = m_astBuilder->create(); - propagateFuncType->resultType = diffFuncType->resultType; - propagateFuncType->paramTypes.addRange(diffFuncType->paramTypes); - propagateFuncType->paramTypes.add(intermediateType); - setFuncTypeIntoRequirementDecl(reqDecl, propagateFuncType); - auto reqRef = m_astBuilder->create(); - reqRef->referencedDecl = reqDecl; - reqRef->parentDecl = decl; - decl->members.add(reqRef); - } - isDiffFunc = true; } if (isDiffFunc) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index a4c79d09a..e1832b9eb 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -851,6 +851,8 @@ struct DiffTransposePass if (as(child) || as(child)) continue; + if (as(child)) + continue; if (isDifferentialInst(child)) transposeInst(&builder, child); @@ -1332,10 +1334,6 @@ struct DiffTransposePass } } - // The call must have been decorated with the continuation context after splitting. - auto primalContextDecor = fwdCall->findDecoration(); - SLANG_RELEASE_ASSERT(primalContextDecor); - auto baseFn = fwdDiffCallee->getBaseFn(); List args; @@ -1453,20 +1451,52 @@ struct DiffTransposePass argRequiresLoad.add(false); } - // Ensure availability of the primal context var - auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar()); - SLANG_RELEASE_ASSERT(primalContextVar); + // If the callee provides a primal implementation that produces continuation context for propagation phase + // we grab it and pass it as argument to the propagation function. + if (auto primalContextDecor = fwdCall->findDecoration()) + { + // Ensure availability of the primal context var + auto primalContextVar = hoistPrimalInst(builder, primalContextDecor->getBackwardDerivativePrimalContextVar()); + SLANG_RELEASE_ASSERT(primalContextVar); - args.add(builder->emitLoad(primalContextVar)); - argTypes.add(as( + args.add(builder->emitLoad(primalContextVar)); + argTypes.add(as( primalContextVar->getDataType()) ->getValueType()); - argRequiresLoad.add(false); + argRequiresLoad.add(false); + } auto revFnType = builder->getFuncType(argTypes, builder->getVoidType()); - auto revCallee = builder->emitBackwardDifferentiatePropagateInst( - revFnType, - baseFn); + IRInst* revCallee = nullptr; + if (getResolvedInstForDecorations(baseFn)->getOp() == kIROp_LookupWitness) + { + // This is an interface method call, we can simply transcribe it here. + auto specialize = as(baseFn); + auto innerFn = baseFn; + if (specialize) + innerFn = specialize->getBase(); + auto lookupWitness = as(innerFn); + SLANG_RELEASE_ASSERT(lookupWitness); + auto diffDecor = lookupWitness->getRequirementKey()->findDecoration(); + SLANG_RELEASE_ASSERT(diffDecor); + auto diffKey = diffDecor->getBackwardDerivativeFunc(); + revCallee = builder->emitLookupInterfaceMethodInst(builder->getTypeKind(), lookupWitness->getWitnessTable(), diffKey); + if (specialize) + { + List specArgs; + for (UInt i = 0; i < specialize->getArgCount(); i++) + specArgs.add(specialize->getArg(i)); + revCallee = builder->emitSpecializeInst(builder->getTypeKind(), revCallee, specArgs.getCount(), specArgs.getBuffer()); + } + revCallee->setFullType(revFnType); + } + else + { + // All other calls, we insert a `backwardDifferentiate` inst so we will process it in a follow-up iteration. + revCallee = builder->emitBackwardDifferentiatePropagateInst( + revFnType, + baseFn); + } List callArgs; for (auto arg : args) diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 2ebc330f0..a30826370 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -830,28 +830,51 @@ struct DiffUnzipPass { auto func = findSpecializeReturnVal(specialize); auto outerGen = findOuterGeneric(func); - intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen); - List args; - for (UInt i = 0; i < specialize->getArgCount(); i++) - args.add(specialize->getArg(i)); - intermediateType = primalBuilder->emitSpecializeInst( - primalBuilder->getTypeKind(), - intermediateType, - args.getCount(), - args.getBuffer()); + if (func->getOp() == kIROp_LookupWitness) + { + // An interface method won't have intermediate type. + intermediateType = primalBuilder->getVoidType(); + } + else + { + intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(outerGen); + List args; + for (UInt i = 0; i < specialize->getArgCount(); i++) + args.add(specialize->getArg(i)); + intermediateType = primalBuilder->emitSpecializeInst( + primalBuilder->getTypeKind(), + intermediateType, + args.getCount(), + args.getBuffer()); + } } else { - intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn); + if (baseFn->getOp() == kIROp_LookupWitness) + intermediateType = primalBuilder->getVoidType(); + else + intermediateType = primalBuilder->getBackwardDiffIntermediateContextType(baseFn); } - auto intermediateVar = primalBuilder->emitVar((IRType*)intermediateType); - primalBuilder->markInstAsPrimal(intermediateVar); + IRVar* intermediateVar = nullptr; + if (!as(intermediateType)) + { + intermediateVar = primalBuilder->emitVar((IRType*)intermediateType); + primalBuilder->markInstAsPrimal(intermediateVar); + } - primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); - - auto primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); - + IRInst* primalFn = nullptr; + if (intermediateVar) + { + primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); + primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); + } + else + { + // If we decided not to use diff-primal func that stores an reuse context, + // we can just call the original function instead. + primalFn = baseFn; + } List primalArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) { @@ -865,7 +888,8 @@ struct DiffUnzipPass primalArgs.add(arg); } } - primalArgs.add(intermediateVar); + if (intermediateType->getOp() != kIROp_VoidType) + primalArgs.add(intermediateVar); auto mixedDecoration = mixedCall->findDecoration(); SLANG_ASSERT(mixedDecoration); @@ -881,7 +905,8 @@ struct DiffUnzipPass } auto primalVal = primalBuilder->emitCallInst(primalType, primalFn, primalArgs); - primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); + if (intermediateVar) + primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); primalBuilder->markInstAsPrimal(primalVal); SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount()); @@ -960,9 +985,12 @@ struct DiffUnzipPass diffArgs); diffBuilder->markInstAsDifferential(callInst, primalType); - disableIRValidationAtInsert(); - diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar); - enableIRValidationAtInsert(); + if (intermediateVar) + { + disableIRValidationAtInsert(); + diffBuilder->addBackwardDerivativePrimalContextDecoration(callInst, intermediateVar); + enableIRValidationAtInsert(); + } IRInst* diffVal = nullptr; if (as(callInst->getDataType())) diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d09c35eea..261e08168 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -6899,14 +6899,6 @@ struct DeclLoweringVisitor : DeclVisitor { op = kIROp_BackwardDerivativeDecoration; } - else if (as(requirementDecl)) - { - op = kIROp_BackwardDerivativePropagateDecoration; - } - else if (as(requirementDecl)) - { - op = kIROp_BackwardDerivativePrimalDecoration; - } else if (as(requirementDecl)) { op = kIROp_ForwardDerivativeDecoration; @@ -8534,12 +8526,6 @@ struct DeclLoweringVisitor : DeclVisitor UNREACHABLE_RETURN(LoweredValInfo()); } - LoweredValInfo visitBackwardDerivativeIntermediateTypeRequirementDecl(BackwardDerivativeIntermediateTypeRequirementDecl* decl) - { - SLANG_UNUSED(decl); - return LoweredValInfo(getBuilder()->getTypeKind()); - } - LoweredValInfo visitFunctionDeclBase(FunctionDeclBase* decl) { // A function declaration may have multiple, target-specific diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index da5099934..a7d047a0c 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -521,12 +521,6 @@ namespace Slang emitRaw(context, "FwdReq_"); else if (as(decl)) emitRaw(context, "BwdReq_"); - else if (as(decl)) - emitRaw(context, "BwdReq_Prop_"); - else if (as(decl)) - emitRaw(context, "BwdReq_Primal_"); - else if (as(decl)) - emitRaw(context, "BwdReq_CtxType_"); else { // TODO: handle other cases -- cgit v1.2.3