From fc54adee1f7f0ba18591fc84ce5d51ac23afa954 Mon Sep 17 00:00:00 2001 From: Yong He Date: Wed, 26 Apr 2023 17:37:04 -0700 Subject: Autodiff support for dynamically dispatched generic method. (#2846) * Autodiff support for dynamically dispatched generic method. * Fix. * Support dynamically dispatched generic type. --------- Co-authored-by: Yong He --- source/slang/slang-ast-decl.h | 3 +++ source/slang/slang-check-decl.cpp | 8 +++----- source/slang/slang-diagnostic-defs.h | 2 -- source/slang/slang-ir-autodiff-fwd.cpp | 2 +- source/slang/slang-ir-autodiff-rev.cpp | 2 +- source/slang/slang-ir-autodiff-transpose.h | 13 +++---------- source/slang/slang-ir-autodiff-unzip.h | 15 +++++++++------ source/slang/slang-ir-autodiff.cpp | 12 ++++++++++++ source/slang/slang-ir-autodiff.h | 4 ++++ source/slang/slang-ir-lower-witness-lookup.cpp | 1 + source/slang/slang-ir.cpp | 3 ++- source/slang/slang-lower-to-ir.cpp | 14 ++++++++++++-- 12 files changed, 51 insertions(+), 28 deletions(-) (limited to 'source') diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index ccbac0286..e75660c7b 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -526,6 +526,9 @@ class AttributeDecl : public ContainerDecl class DerivativeRequirementDecl : public FunctionDeclBase { SLANG_AST_CLASS(DerivativeRequirementDecl) + + // The original requirement decl. + Decl* originalRequirementDecl = nullptr; }; // A reference to a synthesized decl representing a differentiable function requirement, this decl will diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 0901d2026..b3470e882 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1427,6 +1427,7 @@ namespace Slang varDecl->initExpr = CompleteOverloadCandidate(overloadContext, *overloadContext.bestCandidate); } } + maybeRegisterDifferentiableType(getASTBuilder(), varDecl->getType()); } // Fill in default substitutions for the 'subtype' part of a type constraint decl @@ -4738,7 +4739,6 @@ namespace Slang void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); - if (newContext.getParentDifferentiableAttribute()) { // Register additional types outside the function body first. @@ -5638,11 +5638,8 @@ namespace Slang bool isDiffFunc = false; if (decl->hasModifier() || decl->hasModifier()) { - if (GetOuterGeneric(decl)) - { - getSink()->diagnose(decl, Diagnostics::differentiableGenericInterfaceMethodNotSupported); - } auto reqDecl = m_astBuilder->create(); + reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); auto declRef = DeclRef(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); @@ -5664,6 +5661,7 @@ namespace Slang auto diffFuncType = as(getBackwardDiffFuncType(originalFuncType)); { auto reqDecl = m_astBuilder->create(); + reqDecl->originalRequirementDecl = decl; cloneModifiers(reqDecl, decl); setFuncTypeIntoRequirementDecl(reqDecl, diffFuncType); interfaceDecl->members.add(reqDecl); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index ec8131824..cb441ade8 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -359,8 +359,6 @@ DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '[ DIAGNOSTIC(31147, Error, cannotResolveOriginalFunctionForDerivative, "cannot resolve the original function for the the custom derivative.") DIAGNOSTIC(31148, Error, cannotResolveDerivativeFunction, "cannot resolve the custom derivative function") -DIAGNOSTIC(31149, Error, differentiableGenericInterfaceMethodNotSupported, "`[ForwardDifferentiable] and [BackwardDifferentiable] are not supported on generic interface requirements.") - DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1") // Enums diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index e0b916090..819c6bc57 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -955,7 +955,7 @@ InstPair ForwardDiffTranscriber::transcribeSpecialize(IRBuilder* builder, IRSpec builder->getTypeKind(), diffBaseSpecialize->getBase(), args.getCount(), args.getBuffer()); return InstPair(primalSpecialize, diffSpecialize); } - else if (_isDifferentiableFunc(genericInnerVal)) + else if (_isDifferentiableFunc(genericInnerVal) || as(genericInnerVal)) { List args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) diff --git a/source/slang/slang-ir-autodiff-rev.cpp b/source/slang/slang-ir-autodiff-rev.cpp index 2994a8c31..e5735b831 100644 --- a/source/slang/slang-ir-autodiff-rev.cpp +++ b/source/slang/slang-ir-autodiff-rev.cpp @@ -1273,7 +1273,7 @@ namespace Slang return InstPair(primalSpecialize, diffSpecialize); } - else if (isBackwardDifferentiableFunc(genericInnerVal)) + else if (isBackwardDifferentiableFunc(genericInnerVal) || as(genericInnerVal)) { List args; for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) diff --git a/source/slang/slang-ir-autodiff-transpose.h b/source/slang/slang-ir-autodiff-transpose.h index 8a734446d..910c23708 100644 --- a/source/slang/slang-ir-autodiff-transpose.h +++ b/source/slang/slang-ir-autodiff-transpose.h @@ -655,13 +655,6 @@ struct DiffTransposePass subBuilder.addBackwardDerivativePrimalReturnDecoration(branch, retVal); } - // TODO: Should move this to before all the transposition, but a lot of the - // transposition logic seems to access the parent of blocks to find the func. - // Replace those uses. - // - for (auto block : workList) - block->removeFromParent(); - // At this point, the only block left without terminator insts // should be the last one. Add a void return to complete it. // @@ -1101,7 +1094,7 @@ struct DiffTransposePass }; List writebacks; - auto baseFnType = as(baseFn->getDataType()); + auto baseFnType = as(getResolvedInstForDecorations(baseFn->getDataType())); SLANG_RELEASE_ASSERT(baseFnType); SLANG_RELEASE_ASSERT(fwdCall->getArgCount() == baseFnType->getParamCount()); @@ -1151,8 +1144,8 @@ struct DiffTransposePass auto pairType = as(arg->getDataType()); auto var = builder->emitVar(arg->getDataType()); - auto diffType = (IRType*)diffTypeContext.getDifferentialForType(builder, pairType->getValueType()); - auto zeroMethod = diffTypeContext.getZeroMethodForType(builder, pairType->getValueType()); + auto diffType = (IRType*)diffTypeContext.getDiffTypeFromPairType(builder, pairType); + auto zeroMethod = diffTypeContext.getDiffZeroMethodFromPairType(builder, pairType); SLANG_ASSERT(zeroMethod); auto diffZero = builder->emitCallInst( diffType, diff --git a/source/slang/slang-ir-autodiff-unzip.h b/source/slang/slang-ir-autodiff-unzip.h index 34f0f6c9b..63b46f779 100644 --- a/source/slang/slang-ir-autodiff-unzip.h +++ b/source/slang/slang-ir-autodiff-unzip.h @@ -210,8 +210,8 @@ struct DiffUnzipPass auto baseFn = _getOriginalFunc(mixedCall); SLANG_RELEASE_ASSERT(baseFn); - auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->differentiateFunctionType( - primalBuilder, baseFn, as(baseFn->getDataType())); + auto primalFuncType = autodiffContext->transcriberSet.primalTranscriber->transcribe( + primalBuilder, baseFn->getDataType()); IRInst* intermediateType = nullptr; @@ -251,12 +251,12 @@ struct DiffUnzipPass intermediateVar = primalBuilder->emitVar((IRType*)intermediateType); primalBuilder->markInstAsPrimal(intermediateVar); } - + IRInst* primalFn = nullptr; if (intermediateVar) { primalBuilder->addBackwardDerivativePrimalContextDecoration(intermediateVar, intermediateVar); - primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst(primalFuncType, baseFn); + primalFn = primalBuilder->emitBackwardDifferentiatePrimalInst((IRType*)primalFuncType, baseFn); } else { @@ -298,7 +298,10 @@ struct DiffUnzipPass primalBuilder->addBackwardDerivativePrimalContextDecoration(primalVal, intermediateVar); primalBuilder->markInstAsPrimal(primalVal); - SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= primalFuncType->getParamCount()); + auto resolvedPrimalFuncType = as(getResolvedInstForDecorations(primalFuncType)); + SLANG_RELEASE_ASSERT(resolvedPrimalFuncType); + + SLANG_RELEASE_ASSERT(mixedCall->getArgCount() <= resolvedPrimalFuncType->getParamCount()); List diffArgs; for (UIndex ii = 0; ii < mixedCall->getArgCount(); ii++) @@ -316,7 +319,7 @@ struct DiffUnzipPass // If arg is a mixed differential (pair), it should have already been split. SLANG_ASSERT(primalArg); SLANG_ASSERT(diffArg); - auto primalParamType = primalFuncType->getParamType(ii); + auto primalParamType = resolvedPrimalFuncType->getParamType(ii); if (auto outType = as(primalParamType)) { diff --git a/source/slang/slang-ir-autodiff.cpp b/source/slang/slang-ir-autodiff.cpp index 4188d2ec8..4e33a01ab 100644 --- a/source/slang/slang-ir-autodiff.cpp +++ b/source/slang/slang-ir-autodiff.cpp @@ -458,6 +458,18 @@ IRInst* DifferentiableTypeConformanceContext::getDiffTypeWitnessFromPairType(IRB return _getDiffTypeWitnessFromPairType(sharedContext, builder, type); } +IRInst* DifferentiableTypeConformanceContext::getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->zeroMethodStructKey); +} + +IRInst* DifferentiableTypeConformanceContext::getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type) +{ + auto witnessTable = type->getWitness(); + return _lookupWitness(builder, witnessTable, sharedContext->addMethodStructKey); +} + void DifferentiableTypeConformanceContext::buildGlobalWitnessDictionary() { for (auto globalInst : sharedContext->moduleInst->getChildren()) diff --git a/source/slang/slang-ir-autodiff.h b/source/slang/slang-ir-autodiff.h index 52cf346b3..91b45c5be 100644 --- a/source/slang/slang-ir-autodiff.h +++ b/source/slang/slang-ir-autodiff.h @@ -177,6 +177,10 @@ struct DifferentiableTypeConformanceContext IRInst* getDiffTypeWitnessFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + IRInst* getDiffZeroMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + + IRInst* getDiffAddMethodFromPairType(IRBuilder* builder, IRDifferentialPairTypeBase* type); + // Lookup and return the 'Differential' type declared in the concrete type // in order to conform to the IDifferentiable interface. // Note that inside a generic block, this will be a witness table lookup instruction diff --git a/source/slang/slang-ir-lower-witness-lookup.cpp b/source/slang/slang-ir-lower-witness-lookup.cpp index c1ee204b0..0e46987c7 100644 --- a/source/slang/slang-ir-lower-witness-lookup.cpp +++ b/source/slang/slang-ir-lower-witness-lookup.cpp @@ -350,6 +350,7 @@ struct WitnessLookupLoweringContext { if (auto specialize = as(use->getUser())) { + builder.setInsertBefore(use->getUser()); List args; for (UInt i = 0; i < specialize->getArgCount(); i++) args.add(specialize->getArg(i)); diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index e74a57424..eefcb9eea 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -7067,6 +7067,8 @@ namespace Slang // and then destroy it (it had better have no uses!) void IRInst::removeAndDeallocate() { + removeAndDeallocateAllDecorationsAndChildren(); + if (auto module = getModule()) { if (getIROpInfo(getOp()).isHoistable()) @@ -7080,7 +7082,6 @@ namespace Slang module->getDeduplicationContext()->getInstReplacementMap().remove(this); } removeArguments(); - removeAndDeallocateAllDecorationsAndChildren(); removeFromParent(); // Run destructor to be sure... diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index d644d01c7..c8a41c7c7 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -7429,7 +7429,12 @@ struct DeclLoweringVisitor : DeclVisitor } else { - if (auto callableDecl = as(requirementDecl)) + CallableDecl* callableDecl = nullptr; + if (auto genDecl = as(requirementDecl)) + callableDecl = as(genDecl->inner); + else + callableDecl = as(requirementDecl); + if (callableDecl) { // Differentiable functions has additional requirements for the derivatives. for (auto diffDecl : callableDecl->getMembersOfType()) @@ -8369,7 +8374,12 @@ struct DeclLoweringVisitor : DeclVisitor LoweredValInfo lowerFuncDeclInContext(IRGenContext* subContext, IRBuilder* subBuilder, FunctionDeclBase* decl, bool emitBody = true) { - auto outerGeneric = emitOuterGenerics(subContext, decl, decl); + IRGeneric* outerGeneric = nullptr; + + if (auto derivativeRequirement = as(decl)) + outerGeneric = emitOuterGenerics(subContext, derivativeRequirement->originalRequirementDecl, derivativeRequirement->originalRequirementDecl); + else + outerGeneric = emitOuterGenerics(subContext, decl, decl); // need to create an IR function here -- cgit v1.2.3