diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-23 09:39:08 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-23 09:39:08 -0800 |
| commit | 97cb4851eed7a43f10196971b08d3d311386ce9f (patch) | |
| tree | 99ba81368068b3345fa23b749108265aa753ed2b | |
| parent | 6178cb601368e977c4aa82e0ae25b8eb1e875d84 (diff) | |
Autodiff through simple dynamic dispatch. (#2527)
* Autodiff through simple dynamic dispatch.
* Revert changes.
* Fix.
Co-authored-by: Yong He <yhe@nvidia.com>
| -rw-r--r-- | source/slang/slang-ast-decl.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.cpp | 40 | ||||
| -rw-r--r-- | source/slang/slang-ast-val.h | 23 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 108 | ||||
| -rw-r--r-- | source/slang/slang-ir-autodiff-fwd.cpp | 204 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 8 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 19 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 80 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-syntax.h | 4 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-autodiff-simple.slang | 48 | ||||
| -rw-r--r-- | tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt | 6 |
13 files changed, 522 insertions, 60 deletions
diff --git a/source/slang/slang-ast-decl.h b/source/slang/slang-ast-decl.h index 4da832d11..e7dc73a85 100644 --- a/source/slang/slang-ast-decl.h +++ b/source/slang/slang-ast-decl.h @@ -518,6 +518,25 @@ class AttributeDecl : public ContainerDecl SyntaxClass<NodeBase> syntaxClass; }; +// A synthesized decl used as a placeholder for a differentiable function requirement. +// This allows us to form an interface requirement key for the derivative of an interface function. +// The synthesized `DerivativeRequirementDecl` will be a child of the original function requirement +// decl after an interface type is checked. +class DerivativeRequirementDecl : public FunctionDeclBase +{ + SLANG_AST_CLASS(DerivativeRequirementDecl) +}; + +class ForwardDerivativeRequirementDecl : public DerivativeRequirementDecl +{ + SLANG_AST_CLASS(ForwardDerivativeRequirementDecl) +}; + +class BackwardDerivativeRequirementDecl : public DerivativeRequirementDecl +{ + SLANG_AST_CLASS(BackwardDerivativeRequirementDecl) +}; + bool isInterfaceRequirement(Decl* decl); } // namespace Slang diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 04af66b50..69f39efb6 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1085,6 +1085,25 @@ class BackwardDifferentiableAttribute : public DifferentiableAttribute SLANG_AST_CLASS(BackwardDifferentiableAttribute) }; + /// The `[BackwardDerivative(function)]` attribute specifies a custom function that should + /// be used as the backward-derivative for the decorated function. +class BackwardDerivativeAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(BackwardDerivativeAttribute) + Expr* funcExpr; +}; + + /// The `[BackwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom + /// backward-derivative implementation for `primalFunction`. +class BackwardDerivativeOfAttribute : public DifferentiableAttribute +{ + SLANG_AST_CLASS(BackwardDerivativeOfAttribute) + + Expr* funcExpr; + + Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. +}; + /// Indicates that the modified declaration is one of the "magic" declarations /// that NVAPI uses to communicate extended operations. When NVAPI is being included /// via the prelude for downstream compilation, declarations with this modifier diff --git a/source/slang/slang-ast-val.cpp b/source/slang/slang-ast-val.cpp index 87e89ef18..a0f0552c6 100644 --- a/source/slang/slang-ast-val.cpp +++ b/source/slang/slang-ast-val.cpp @@ -1516,4 +1516,44 @@ Val* WitnessLookupIntVal::tryFold(ASTBuilder* astBuilder, SubtypeWitness* witnes return witnessResult; } + +bool DifferentiateVal::_equalsValOverride(Val* val) +{ + if (auto other = as<DifferentiateVal>(val)) + { + return other->astNodeType == astNodeType && other->func == func; + } + return false; +} + +void DifferentiateVal::_toTextOverride(StringBuilder& out) +{ + out << "DifferentiateVal("; + out << func; + out << ")"; +} + +HashCode DifferentiateVal::_getHashCodeOverride() +{ + HashCode result = (HashCode)astNodeType; + result = combineHash(result, func.getHashCode()); + return result; +} + +Val* DifferentiateVal::_substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff) +{ + int diff = 0; + auto newFunc = func.substituteImpl(astBuilder, subst, &diff); + *ioDiff += diff; + if (diff) + { + auto result = as<DifferentiateVal>(astBuilder->createByNodeType(astNodeType)); + result->func = newFunc; + return result; + } + // Nothing found: don't substitute. + return this; +} + + } // namespace Slang diff --git a/source/slang/slang-ast-val.h b/source/slang/slang-ast-val.h index b52984f8b..31b74a499 100644 --- a/source/slang/slang-ast-val.h +++ b/source/slang/slang-ast-val.h @@ -490,4 +490,27 @@ class SNormModifierVal : public ResourceFormatModifierVal Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); }; + /// Represents the result of differentiating a function. +class DifferentiateVal : public Val +{ + SLANG_AST_CLASS(DifferentiateVal) + + DeclRef<Decl> func; + + bool _equalsValOverride(Val* val); + void _toTextOverride(StringBuilder& out); + HashCode _getHashCodeOverride(); + Val* _substituteImplOverride(ASTBuilder* astBuilder, SubstitutionSet subst, int* ioDiff); +}; + +class ForwardDifferentiateVal : public DifferentiateVal +{ + SLANG_AST_CLASS(ForwardDifferentiateVal) +}; + +class BackwardDifferentiateVal : public DifferentiateVal +{ + SLANG_AST_CLASS(BackwardDifferentiateVal) +}; + } // namespace Slang diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 5a1218abe..36a1061c9 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -107,6 +107,9 @@ namespace Slang void visitAccessorDecl(AccessorDecl* decl); void visitSetterDecl(SetterDecl* decl); + + void cloneModifiers(Decl* dest, Decl* src); + void setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType); }; struct SemanticsDeclRedeclarationVisitor @@ -1866,6 +1869,32 @@ namespace Slang return false; } + bool hasBackwardDerivative = false; + bool hasForwardDerivative = false; + if (requiredMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>()) + { + if (!satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + { + // A non-`BackwardDifferentiable` method can't satisfy a `BackwardDifferentiable` requirement and vice versa. + return false; + } + hasBackwardDerivative = true; + hasForwardDerivative = true; + } + else if (requiredMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>()) + { + if (!satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDifferentiableAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<ForwardDerivativeAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDifferentiableAttribute>() + && !satisfyingMemberDeclRef.getDecl()->hasModifier<BackwardDerivativeAttribute>()) + { + // A non-`ForwardDifferentiable` method can't satisfy a `ForwardDifferentiable` requirement and vice versa. + return false; + } + hasForwardDerivative = true; + } + // A signature matches the required one if it has the right number of parameters, // and those parameters have the right types, and also the result/return type // is the required one. @@ -1896,6 +1925,24 @@ namespace Slang witnessTable->add( requiredMemberDeclRef.getDecl(), RequirementWitness(satisfyingMemberDeclRef)); + + if (hasForwardDerivative) + { + auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<ForwardDerivativeRequirementDecl>(); + SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); + ForwardDifferentiateVal* val = m_astBuilder->create<ForwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + } + + if (hasBackwardDerivative) + { + auto reqDecl = requiredMemberDeclRef.getDecl()->getMembersOfType<BackwardDerivativeRequirementDecl>(); + SLANG_RELEASE_ASSERT(reqDecl.isNonEmpty()); + BackwardDifferentiateVal* val = m_astBuilder->create<BackwardDifferentiateVal>(); + val->func = satisfyingMemberDeclRef; + witnessTable->add(reqDecl.getFirst(), RequirementWitness(val)); + } return true; } @@ -5515,6 +5562,43 @@ namespace Slang } } + void SemanticsDeclHeaderVisitor::cloneModifiers(Decl* dest, Decl* src) + { + dest->modifiers = src->modifiers; + } + void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(CallableDecl* decl, FuncType* funcType) + { + if (!funcType) + return; + decl->returnType.type = funcType->getResultType(); + decl->errorType.type = funcType->getErrorType(); + for (UInt i = 0; i < funcType->getParamCount(); i++) + { + auto paramType = funcType->getParamType(i); + if (auto dirType = as<ParamDirectionType>(paramType)) + paramType = dirType->getValueType(); + auto param = m_astBuilder->create<ParamDecl>(); + param->type.type = paramType; + auto paramDir = funcType->getParamDirection(i); + switch (paramDir) + { + case ParameterDirection::kParameterDirection_InOut: + addModifier(param, m_astBuilder->create<InOutModifier>()); + break; + case ParameterDirection::kParameterDirection_Out: + addModifier(param, m_astBuilder->create<OutModifier>()); + break; + case ParameterDirection::kParameterDirection_Ref: + addModifier(param, m_astBuilder->create<RefModifier>()); + break; + default: + break; + } + decl->members.add(param); + param->parentDecl = decl; + } + } + void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl) { for(auto paramDecl : decl->getParameters()) @@ -5532,6 +5616,30 @@ namespace Slang errorType = TypeExp(m_astBuilder->getBottomType()); } decl->errorType = errorType; + + if (isInterfaceRequirement(decl)) + { + if (decl->hasModifier<ForwardDifferentiableAttribute>()) + { + auto reqDecl = m_astBuilder->create<ForwardDerivativeRequirementDecl>(); + cloneModifiers(reqDecl, decl); + auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto diffFuncType = getForwardDiffFuncType(getFuncType(m_astBuilder, declRef)); + setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); + decl->members.add(reqDecl); + reqDecl->parentDecl = decl; + } + if (decl->hasModifier<BackwardDifferentiableAttribute>()) + { + auto reqDecl = m_astBuilder->create<BackwardDerivativeRequirementDecl>(); + cloneModifiers(reqDecl, decl); + auto declRef = DeclRef<CallableDecl>(decl, createDefaultSubstitutions(m_astBuilder, this, decl)); + auto diffFuncType = getBackwardDiffFuncType(getFuncType(m_astBuilder, declRef)); + setFuncTypeIntoRequirementDecl(reqDecl, as<FuncType>(diffFuncType)); + decl->members.add(reqDecl); + reqDecl->parentDecl = decl; + } + } } void SemanticsDeclHeaderVisitor::visitFuncDecl(FuncDecl* funcDecl) diff --git a/source/slang/slang-ir-autodiff-fwd.cpp b/source/slang/slang-ir-autodiff-fwd.cpp index 03e81c5b5..04a898ea9 100644 --- a/source/slang/slang-ir-autodiff-fwd.cpp +++ b/source/slang/slang-ir-autodiff-fwd.cpp @@ -100,8 +100,19 @@ struct JVPTranscriber return instMapD.ContainsKey(origInst); } + bool shouldUseOriginalAsPrimal(IRInst* origInst) + { + if (as<IRGlobalValueWithCode>(origInst)) + return true; + if (origInst->parent && origInst->parent->getOp() == kIROp_Module) + return true; + return false; + } + IRInst* lookupPrimalInst(IRInst* origInst) { + if (shouldUseOriginalAsPrimal(origInst)) + return origInst; return cloneEnv.mapOldValToNew[origInst]; } @@ -112,9 +123,11 @@ struct JVPTranscriber bool hasPrimalInst(IRInst* origInst) { + if (shouldUseOriginalAsPrimal(origInst)) + return true; return cloneEnv.mapOldValToNew.ContainsKey(origInst); } - + IRInst* findOrTranscribeDiffInst(IRBuilder* builder, IRInst* origInst) { if (!hasDifferentialInst(origInst)) @@ -128,6 +141,9 @@ struct JVPTranscriber IRInst* findOrTranscribePrimalInst(IRBuilder* builder, IRInst* origInst) { + if (shouldUseOriginalAsPrimal(origInst)) + return origInst; + if (!hasPrimalInst(origInst)) { transcribe(builder, origInst); @@ -687,7 +703,10 @@ struct JVPTranscriber IRInst* diffCallee = nullptr; - if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>()) + if (instMapD.TryGetValue(origCallee, diffCallee)) + { + } + else if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRForwardDerivativeDecoration>()) { // If the user has already provided an differentiated implementation, use that. diffCallee = derivativeReferenceDecor->getForwardDerivativeFunc(); @@ -901,48 +920,111 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } - InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize) + IRInst* findInterfaceRequirement(IRInterfaceType* type, IRInst* key) { - // In general, we should not see any specialize insts at this stage. - // The exceptions are target intrinsics. - auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); - if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>()) + for (UInt i = 0; i < type->getOperandCount(); i++) { - // Look for an IRForwardDerivativeDecoration on the specialize inst. - // (Normally, this would be on the inner IRFunc, but in this case only the JVP func - // can be specialized, so we put a decoration on the IRSpecialize) - // - if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) + if (auto req = as<IRInterfaceRequirementEntry>(type->getOperand(i))) { - auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); + if (req->getRequirementKey() == key) + return req->getRequirementVal(); + } + } + return nullptr; + } - // Make sure this isn't itself a specialize . - SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); + InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) + { + auto primalBase = findOrTranscribePrimalInst(builder, origSpecialize->getBase()); + List<IRInst*> primalArgs; + for (UInt i = 0; i < origSpecialize->getArgCount(); i++) + { + primalArgs.add(findOrTranscribePrimalInst(builder, origSpecialize->getArg(i))); + } + auto primalType = findOrTranscribePrimalInst(builder, origSpecialize->getFullType()); + auto primalSpecialize = (IRSpecialize*)builder->emitSpecializeInst( + (IRType*)primalType, primalBase, primalArgs.getCount(), primalArgs.getBuffer()); - return InstPair(jvpFunc, jvpFunc); + IRInst* diffBase = nullptr; + if (instMapD.TryGetValue(origSpecialize->getBase(), diffBase)) + { + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); } + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); + } + + auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); + // Look for an IRForwardDerivativeDecoration on the specialize inst. + // (Normally, this would be on the inner IRFunc, but in this case only the JVP func + // can be specialized, so we put a decoration on the IRSpecialize) + // + if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRForwardDerivativeDecoration>()) + { + auto jvpFunc = jvpFuncDecoration->getForwardDerivativeFunc(); + + // Make sure this isn't itself a specialize . + SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); + + return InstPair(primalSpecialize, jvpFunc); + } + else if (auto derivativeDecoration = genericInnerVal->findDecoration<IRForwardDerivativeDecoration>()) + { + diffBase = derivativeDecoration->getForwardDerivativeFunc(); + List<IRInst*> args; + for (UInt i = 0; i < primalSpecialize->getArgCount(); i++) + { + args.add(primalSpecialize->getArg(i)); + } + auto diffSpecialize = builder->emitSpecializeInst( + builder->getTypeKind(), diffBase, args.getCount(), args.getBuffer()); + return InstPair(primalSpecialize, diffSpecialize); } else { - getSink()->diagnose(origSpecialize->sourceLoc, - Diagnostics::unexpected, - "should not be attempting to differentiate anything specialized here."); + return InstPair(primalSpecialize, nullptr); } - - return InstPair(nullptr, nullptr); } - InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup) + InstPair transcribeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* lookupInst) { - // This is slightly counter-intuitive, but we don't perform any differentiation - // logic here. We simple clone the original lookup which points to the original function, - // or the cloned version in case we're inside a generic scope. - // The differentiation logic is inserted later when this is used in an IRCall. - // This decision is mostly to maintain a uniform convention of ForwardDifferentiate(Lookup(Table)) - // rather than have Lookup(ForwardDifferentiate(Table)) - // - auto diffLookup = cloneInst(&cloneEnv, builder, origLookup); - return InstPair(diffLookup, diffLookup); + auto primalWt = findOrTranscribePrimalInst(builder, lookupInst->getWitnessTable()); + auto primalKey = findOrTranscribePrimalInst(builder, lookupInst->getRequirementKey()); + auto primalType = findOrTranscribePrimalInst(builder, lookupInst->getFullType()); + auto primal = (IRSpecialize*)builder->emitLookupInterfaceMethodInst((IRType*)primalType, primalWt, primalKey); + + auto interfaceType = as<IRInterfaceType>(as<IRWitnessTableTypeBase>(lookupInst->getWitnessTable()->getDataType())->getConformanceType()); + if (!interfaceType) + { + return InstPair(primal, nullptr); + } + auto dict = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); + if (!dict) + { + return InstPair(primal, nullptr); + } + + for (auto child : dict->getChildren()) + { + if (auto item = as<IRForwardDifferentiableMethodRequirementDictionaryItem>(child)) + { + if (item->getOperand(0) == lookupInst->getRequirementKey()) + { + auto diffKey = item->getOperand(1); + if (auto diffType = findInterfaceRequirement(interfaceType, diffKey)) + { + auto diff = builder->emitLookupInterfaceMethodInst((IRType*)diffType, primalWt, diffKey); + return InstPair(primal, diff); + } + break; + } + } + } + return InstPair(primal, nullptr); } // In differential computation, the 'default' differential value is always zero. @@ -1188,6 +1270,12 @@ struct JVPTranscriber return InstPair(primalPair, diffPair); } + InstPair trascribeNonDiffInst(IRBuilder* builder, IRInst* origInst) + { + auto primal = cloneInst(&cloneEnv, builder, origInst); + return InstPair(primal, nullptr); + } + InstPair transcribeDifferentialPairGetElement(IRBuilder* builder, IRInst* origInst) { SLANG_ASSERT( @@ -1335,6 +1423,7 @@ struct JVPTranscriber // instsInProgress.Add(origInst); InstPair pair = transcribeInst(builder, origInst); + instsInProgress.Remove(origInst); if (auto primalInst = pair.primal) { @@ -1363,11 +1452,9 @@ struct JVPTranscriber } return pair.differential; } - instsInProgress.Remove(origInst); - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::internalCompilerError, - "failed to transcibe instruction"); + Diagnostics::internalCompilerError, + "failed to transcibe instruction"); return nullptr; } @@ -1407,7 +1494,10 @@ struct JVPTranscriber case kIROp_Construct: return transcribeConstruct(builder, origInst); - + + case kIROp_lookup_interface_method: + return transcribeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); + case kIROp_Call: return transcribeCall(builder, as<IRCall>(origInst)); @@ -1429,9 +1519,6 @@ struct JVPTranscriber case kIROp_Specialize: return transcribeSpecialize(builder, as<IRSpecialize>(origInst)); - case kIROp_lookup_interface_method: - return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); - case kIROp_FieldExtract: case kIROp_FieldAddress: return transcribeFieldExtract(builder, origInst); @@ -1450,6 +1537,15 @@ struct JVPTranscriber case kIROp_DifferentialPairGetPrimal: case kIROp_DifferentialPairGetDifferential: return transcribeDifferentialPairGetElement(builder, origInst); + case kIROp_ExtractExistentialWitnessTable: + case kIROp_ExtractExistentialType: + case kIROp_ExtractExistentialValue: + case kIROp_WrapExistential: + case kIROp_MakeExistential: + case kIROp_MakeExistentialWithRTTI: + return trascribeNonDiffInst(builder, origInst); + case kIROp_StructKey: + return InstPair(origInst, nullptr); } // If none of the cases have been hit, check if the instruction is a @@ -1496,7 +1592,6 @@ struct JVPTranscriber return transcribeGeneric(builder, as<IRGeneric>(origInst)); } - // If we reach this statement, the instruction type is likely unhandled. getSink()->diagnose(origInst->sourceLoc, Diagnostics::unimplemented, @@ -1558,6 +1653,7 @@ struct ForwardDerivativePass : public InstPassBase { case kIROp_Func: case kIROp_Specialize: + case kIROp_lookup_interface_method: autoDiffWorkList.add(inst); break; default: @@ -1587,29 +1683,15 @@ struct ForwardDerivativePass : public InstPassBase differentiateInst->replaceUsesWith(existingDiffFunc); differentiateInst->removeAndDeallocate(); } - else if (isMarkedForForwardDifferentiation(baseInst)) - { - if (as<IRFunc>(baseInst) || as<IRGeneric>(baseInst)) - { - IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst); - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } - } else { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Requested differentiation on a function that isn't marked as differentiable."); + IRBuilder subBuilder(*builder); + subBuilder.setInsertBefore(differentiateInst); + IRInst* diffFunc = transcriberStorage.transcribe(&subBuilder, baseInst); + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); } - } } // Actually synthesize the derivatives. @@ -1638,6 +1720,8 @@ struct ForwardDerivativePass : public InstPassBase // bool isMarkedForForwardDifferentiation(IRInst* callable) { + if (auto gen = as<IRGeneric>(callable)) + callable = findGenericReturnVal(gen); return callable->findDecoration<IRForwardDifferentiableDecoration>() != nullptr; } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index e11f98dcd..cc5261d14 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -746,6 +746,9 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /* Differentiable Type Dictionary */ INST(DifferentiableTypeDictionaryDecoration, DifferentiableTypeDictionaryDecoration, 0, PARENT) + /// Decorates an interface type and stores the mapping from a normal function requirement key to its derivative requirement key. + INST(DifferentiableMethodRequirementDictionaryDecoration, DifferentiableMethodRequirementDictionaryDecoration, 0, PARENT) + /// Marks a struct type as being used as a structured buffer block. /// Recognized by SPIRV-emit pass so we can emit a SPIRV `BufferBlock` decoration. INST(SPIRVBufferBlockDecoration, spvBufferBlock, 0, 0) @@ -846,6 +849,11 @@ INST(ExistentialTypeSpecializationDictionary, ExistentialTypeSpecializationDicti /* Differentiable Type Dictionary */ INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0) +/* DifferentiableMethodRequirementDictionaryItem */ + INST(ForwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0) + INST(BackwardDifferentiableMethodRequirementDictionaryItem, DifferentiableMethodRequirementDictionaryItem, 0, 0) +INST_RANGE(DifferentiableMethodRequirementDictionaryItem, ForwardDifferentiableMethodRequirementDictionaryItem, BackwardDifferentiableMethodRequirementDictionaryItem) + #undef PARENT #undef USE_OTHER #undef INST_RANGE diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 4434210c9..5c0401cc2 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -659,6 +659,25 @@ struct IRDifferentiableTypeDictionaryDecoration : IRDecoration IR_LEAF_ISA(DifferentiableTypeDictionaryDecoration) }; +struct IRDifferentiableMethodRequirementDictionaryDecoration : IRDecoration +{ + IR_LEAF_ISA(DifferentiableMethodRequirementDictionaryDecoration) +}; + +struct IRDifferentiableMethodRequirementDictionaryItem : IRInst +{ + IR_PARENT_ISA(DifferentiableMethodRequirementDictionaryItem) +}; + +struct IRForwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem +{ + IR_LEAF_ISA(ForwardDifferentiableMethodRequirementDictionaryItem) +}; + +struct IRBackwardDifferentiableMethodRequirementDictionaryItem : IRDifferentiableMethodRequirementDictionaryItem +{ + IR_LEAF_ISA(BackwardDifferentiableMethodRequirementDictionaryItem) +}; // An instruction that specializes another IR value // (representing a generic) to a particular set of generic arguments diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 57267a9ea..a0becdafa 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -1398,6 +1398,27 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower midToSup)); } + LoweredValInfo visitForwardDifferentiateVal(ForwardDifferentiateVal* val) + { + // TODO: properly fill in type info here. + // We should consider fold all cases of witness table entries to `Val`, and make the `DeclRef` case a `DeclRefVal`. + // So that we can hold the type in `DeclRefVal`. + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->emitForwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val); + return LoweredValInfo::simple(diff); + } + + LoweredValInfo visitBackwardDifferentiateVal(BackwardDifferentiateVal* val) + { + auto funcVal = emitDeclRef(context, val->func, context->irBuilder->getTypeKind()); + SLANG_RELEASE_ASSERT(funcVal.flavor == LoweredValInfo::Flavor::Simple); + + auto diff = getBuilder()->emitBackwardDifferentiateInst(getBuilder()->getTypeKind(), funcVal.val); + return LoweredValInfo::simple(diff); + } + LoweredValInfo visitDifferentialBottomSubtypeWitness(DifferentialBottomSubtypeWitness*) { return LoweredValInfo(); @@ -6786,6 +6807,24 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo::simple(assocType); } + void insertRequirementKeyAssociation(IRInterfaceType* interfaceType, Decl* requirementDecl, IRInst* originalKey, IRInst* associatedKey) + { + auto decor = interfaceType->findDecoration<IRDifferentiableMethodRequirementDictionaryDecoration>(); + if (!decor) + { + decor = + (IRDifferentiableMethodRequirementDictionaryDecoration*) + context->irBuilder->addDecoration( + interfaceType, kIROp_DifferentiableMethodRequirementDictionaryDecoration); + } + auto op = as<ForwardDerivativeRequirementDecl>(requirementDecl) + ? kIROp_ForwardDifferentiableMethodRequirementDictionaryItem + : kIROp_BackwardDifferentiableMethodRequirementDictionaryItem; + IRInst* args[] = {originalKey, associatedKey}; + auto assoc = context->irBuilder->emitIntrinsicInst(nullptr, op, 2, args); + assoc->insertAtEnd(decor); + } + LoweredValInfo visitInterfaceDecl(InterfaceDecl* decl) { // The members of an interface will turn into the keys that will @@ -6824,6 +6863,14 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> { operandCount += associatedTypeDecl->getMembersOfType<TypeConstraintDecl>().getCount(); } + else if (auto callableDecl = as<CallableDecl>(requirementDecl)) + { + // Differentiable functions has additional requirements for the derivatives. + if (callableDecl->getMembersOfType<ForwardDerivativeRequirementDecl>().getCount()) + operandCount++; + if (callableDecl->getMembersOfType<BackwardDerivativeRequirementDecl>().getCount()) + operandCount++; + } } // Allocate an IRInterfaceType with the `operandCount` operands. @@ -6907,6 +6954,38 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } else { + if (auto callableDecl = as<CallableDecl>(requirementDecl)) + { + // Differentiable functions has additional requirements for the derivatives. + for (auto diffDecl : callableDecl->getMembersOfType<DerivativeRequirementDecl>()) + { + auto diffKey = getInterfaceRequirementKey(diffDecl); + IRInst* diffVal = ensureDecl(subContext, diffDecl).val; + auto diffEntry = subBuilder->createInterfaceRequirementEntry(diffKey, diffVal); + if (diffVal) + { + switch (diffVal->getOp()) + { + case kIROp_Func: + case kIROp_Generic: + { + // Remove lowered `IRFunc`s since we only care about + // function types. + auto reqType = diffVal->getFullType(); + diffEntry->setRequirementVal(reqType); + break; + } + default: + break; + } + } + irInterface->setOperand(entryIndex, diffEntry); + entryIndex++; + + setValue(context, diffDecl, LoweredValInfo::simple(diffEntry)); + insertRequirementKeyAssociation(irInterface, diffDecl, requirementKey, diffKey); + } + } // Add lowered requirement entry to current decl mapping to prevent // the function requirements from being lowered again when we get to // `ensureAllDeclsRec`. @@ -6914,6 +6993,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> } } + addNameHint(context, irInterface, decl); addLinkageDecoration(context, irInterface, decl); if (auto anyValueSizeAttr = decl->findModifier<AnyValueSizeAttribute>()) diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index bdb5465e9..6ea9ea01e 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -517,6 +517,10 @@ namespace Slang emitQualifiedName(context, innerDecl); return; } + else if (as<ForwardDerivativeRequirementDecl>(decl)) + emitRaw(context, "FwdReq_"); + else if (as<BackwardDerivativeRequirementDecl>(decl)) + emitRaw(context, "BwdReq_"); else { // TODO: handle other cases diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 8f88ddb2d..2ceb7a9fd 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -309,6 +309,10 @@ namespace Slang GenericSubstitution* findInnerMostGenericSubstitution(Substitutions* subst); + ThisTypeSubstitution* findThisTypeSubstitution( + const Substitutions* substs, + InterfaceDecl* interfaceDecl); + enum class UserDefinedAttributeTargets { None = 0, diff --git a/tests/autodiff/dynamic-dispatch-autodiff-simple.slang b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang new file mode 100644 index 000000000..1247253f9 --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang @@ -0,0 +1,48 @@ +// Test calling differentiable function through dynamic dispatch. + +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[anyValueSize(16)] +interface IInterface +{ + [ForwardDifferentiable] + static float calc(float x); +} + +struct A : IInterface +{ + [ForwardDifferentiable] + static float calc(float x) { return x * x * x; } +}; + +struct B : IInterface +{ + [ForwardDifferentiable] + static float calc(float x) { return x * x; } +}; + +[ForwardDifferentiable] +float sqr(IInterface obj, float x) +{ + return obj.calc(x); +} + +//TEST_INPUT: type_conformance A:IInterface = 0 +//TEST_INPUT: type_conformance B:IInterface = 1 + + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + var obj = createDynamicObject<IInterface>(dispatchThreadID.x, 0); // A + outputBuffer[0] = __fwd_diff(sqr)(obj, DifferentialPair<float>(2.0, 1.0)).d; // A.calc, expect 12 + + obj = createDynamicObject<IInterface>(dispatchThreadID.x + 1, 0); // B + outputBuffer[1] = __fwd_diff(sqr)(obj, DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3 + + outputBuffer[2] = __fwd_diff(obj.calc)(DifferentialPair<float>(1.5, 1.0)).d; // B.calc, expect 3 +} diff --git a/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt new file mode 100644 index 000000000..1b1844a5d --- /dev/null +++ b/tests/autodiff/dynamic-dispatch-autodiff-simple.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +12.000000 +3.000000 +3.000000 +0.000000 +0.000000 |
