diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-26 22:21:29 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-26 19:21:29 -0700 |
| commit | f7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch) | |
| tree | 574dff2bcb8c5a3de9e74d18346a424c82d62a7a | |
| parent | 939be44ca23476e622dfb24a592383fe2a1da61f (diff) | |
Adding a differentiable standard library (#2465)
| -rw-r--r-- | source/slang/diff.meta.slang | 63 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-conformance.cpp | 1 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 97 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-call.cpp | 35 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 282 | ||||
| -rw-r--r-- | source/slang/slang-ir-specialize.cpp | 69 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 20 | ||||
| -rw-r--r-- | tests/autodiff/custom-intrinsic.slang | 119 | ||||
| -rw-r--r-- | tests/autodiff/custom-intrinsic.slang.expected.txt | 6 | ||||
| -rw-r--r-- | tests/autodiff/dstdlib.slang | 37 | ||||
| -rw-r--r-- | tests/autodiff/dstdlib.slang.expected.txt | 7 |
14 files changed, 591 insertions, 153 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index f314e0487..ea204c839 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -5,7 +5,7 @@ syntax __differentiate_jvp : JVPDerivativeModifier; // Custom JVP Function reference -__attributeTarget(FuncDecl) +__attributeTarget(FunctionDeclBase) attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute; /// Interface to denote types as differentiable. @@ -39,7 +39,7 @@ extension float : IDifferentiable [__unsafeForceInlineEarly] static Differential zero() { - return 0.f; + return float(0.f); } [__unsafeForceInlineEarly] @@ -151,3 +151,62 @@ struct __DifferentialPair return p(); } }; + +typealias IDFloat = IFloat & IDifferentiable; + +namespace dstd +{ + // Natural Exponent + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_exp($0)") + __target_intrinsic(cpp, "$P_exp($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") + [__custom_jvp(d_exp<T>)] + T exp(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + exp(dpx.p()), + T.dmul(exp(dpx.p()), dpx.d())); + } + + // Sine + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_sin($0)") + __target_intrinsic(cpp, "$P_sin($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") + [__custom_jvp(d_sin<T>)] + T sin(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + sin(dpx.p()), + T.dmul(cos(dpx.p()), dpx.d())); + } + + // Cosine + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_cos($0)") + __target_intrinsic(cpp, "$P_cos($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") + [__custom_jvp(d_cos<T>)] + T cos(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + cos(dpx.p()), + T.dmul(-sin(dpx.p()), dpx.d())); + } +}; diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index baa6de73a..fca628a49 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -20,8 +20,8 @@ class DeclRefExpr: public Expr // The declaration of the symbol being referenced - DeclRef<Decl> declRef; + DeclRef<Decl> declRef; // The name of the symbol being referenced Name* name = nullptr; diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp index cf362dcdd..2c9977082 100644 --- a/source/slang/slang-check-conformance.cpp +++ b/source/slang/slang-check-conformance.cpp @@ -353,6 +353,7 @@ namespace Slang for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef) ) { + ensureDecl(constraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl); auto sub = getSub(m_astBuilder, constraintDeclRef); auto sup = getSup(m_astBuilder, constraintDeclRef); diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index ac1d624c2..2dc08262e 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -1831,8 +1831,8 @@ namespace Slang Expr* CheckExpr(Expr* expr); - Expr* CheckInvokeExprWithCheckedOperands(InvokeExpr *expr); + Expr* CheckInvokeExprWithCheckedOperands(InvokeExpr *expr); // Get the type to use when referencing a declaration QualType GetTypeForDeclRef(DeclRef<Decl> declRef, SourceLoc loc); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index 7e11ee3ca..20e5d5378 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -628,13 +628,102 @@ namespace Slang else if (auto customJVPAttr = as<CustomJVPAttribute>(attr)) { SLANG_ASSERT(attr->args.getCount() == 1); - + SLANG_ASSERT(as<Decl>(attrTarget)); + // Ensure that the argument is a reference to a function definition or declaration. - auto funcExpr = as<DeclRefExpr>(CheckTerm(attr->args[0])); - if (!as<FuncType>(funcExpr->type)) + auto diffExpr = CheckTerm(attr->args[0]); + if (diffExpr->type == getASTBuilder()->getErrorType()) + { + // Could not resolve the term. + getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); + return false; + } + + // Either diffExpr has a function type, or it is a reference to a generic. + if (!as<FuncType>(diffExpr->type) && + !(as<DeclRefExpr>(diffExpr) && + as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr)) + { return false; + } - customJVPAttr->funcDeclRef = funcExpr; + auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef; + + UCount genericLevels = 0; + // If we've grabbed the outer generic for some reason, + // recursively construct GenericAppExpr<...>(generic) + // and check that to get a specialized func. + // + while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr) + { + // Forward to the inner decl + diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner); + + // Increment counter. + genericLevels += 1; + } + + auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl); + auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl); + Expr* currentDiffExpr = diffExpr; + + // Go back through each level, and use generic declarations in the + // target's generic scope as arguments for the diff function's generic. + // + for (UIndex ii = 0; ii < genericLevels; ii++) + { + // Nest our expression inside a GenericAppExpr + auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>(); + genericAppExpr->functionExpr = currentDiffExpr; + + // Construct references to the generic args in the current scope. + // TODO: Probably an easier way to do this. + for (auto member : targetGeneric->members) + { + if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) + { + genericAppExpr->arguments.add( + ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr)); + } + else if (auto valueParamDecl = as<GenericValueParamDecl>(member)) + { + genericAppExpr->arguments.add( + ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr)); + } + } + + // Set our generic-app-expr as the new expr. + currentDiffExpr = genericAppExpr; + + // Peel the generic layer. + diffGeneric = as<GenericDecl>(diffGeneric->parentDecl); + targetGeneric = as<GenericDecl>(targetGeneric->parentDecl); + } + + if ((diffGeneric == nullptr && targetGeneric != nullptr) || + (targetGeneric == nullptr && diffGeneric != nullptr)) + { + //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget); + SLANG_UNEXPECTED(""); + } + + // If we had to change currentDiffExpr, then re-check the expr. + if (!currentDiffExpr->type) + { + currentDiffExpr = CheckTerm(currentDiffExpr); + } + + // Ensure that the argument is a reference to a function definition or declaration. + auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr); + auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef; + + if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc()))) + { + getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef); + } + + // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) + customJVPAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 6e6a6f5e5..9e939e476 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -338,6 +338,10 @@ DIAGNOSTIC(31141, Error, definitionOfExternDeclMismatchesOriginalDefinition, "`e DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl '$0' has ambiguous original definitions.") DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.") +DIAGNOSTIC(31144, Error, customDerivativeNotAFunction, "$0, used as a custom derivative, is not a function") +DIAGNOSTIC(31145, Error, customDerivativeGenericSignatureMismatch, "cannot use $0 as custom derivative for $1. generic signature does not match") +DIAGNOSTIC(31146, Error, customDerivativeSignatureMismatch, "cannot use $0 as custom derivative for $1. signature does not match") +DIAGNOSTIC(31146, Error, invalidCustomDerivative, "unable to resolve custom differential for $0.") // Enums DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp index ee78246fe..3f2c6c789 100644 --- a/source/slang/slang-ir-diff-call.cpp +++ b/source/slang/slang-ir-diff-call.cpp @@ -57,22 +57,22 @@ struct DerivativeCallProcessContext // First get base function auto origCallable = derivOfInst->getBaseFn(); - IRSpecialize* specialization = nullptr; - - // If the base is a specialize inst, get the inner fn. + // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc)) + // Check the specialize inst for a reference to the derivative fn. + // if (auto origSpecialize = as<IRSpecialize>(origCallable)) { - specialization = origSpecialize; - origCallable = origSpecialize->getBase(); + if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + jvpCallable = jvpSpecRefDecorator->getJVPFunc(); + } } - // We should have either a generic or a function reference on our hands. - SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable)); - - // Resolve the derivative function. + // Resolve the derivative function for an IRJVPDifferentiate(IRFunc) // // Check for the 'JVPDerivativeReference' decorator on the // base function. + // if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>()) { jvpCallable = jvpRefDecorator->getJVPFunc(); @@ -80,20 +80,9 @@ struct DerivativeCallProcessContext SLANG_ASSERT(jvpCallable); - if (specialization) - { - // Replace the specialization target with the JVP func. - specialization->setOperand(0, jvpCallable); - - // Then replace the JVPDifferentiate inst with the specialization. - derivOfInst->replaceUsesWith(specialization); - } - else - { - // Substitute all uses of the 'derivativeOf' operation - // with the resolved derivative function. - derivOfInst->replaceUsesWith(jvpCallable); - } + // Substitute all uses of the 'derivativeOf' operation + // with the resolved derivative function. + derivOfInst->replaceUsesWith(jvpCallable); // Remove the 'derivativeOf' inst. derivOfInst->removeAndDeallocate(); diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index b97556ab1..1a86506b3 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -1091,100 +1091,106 @@ struct JVPTranscriber // in the current transcription context. // InstPair transcribeCall(IRBuilder* builder, IRCall* origCall) - { - if (as<IRFunc>(origCall->getCallee())) - { - auto origCallee = origCall->getCallee(); + { + + IRInst* origCallee = origCall->getCallee(); - // Since concrete functions are globals, the primal callee is the same - // as the original callee. + if (!origCallee) + { + // Note that this can only happen if the callee is a result + // of a higher-order operation. For now, we assume that we cannot + // differentiate such calls safely. + // TODO(sai): Should probably get checked in the front-end. // - auto primalCallee = origCallee; + getSink()->diagnose(origCall->sourceLoc, + Diagnostics::internalCompilerError, + "attempting to differentiate unresolved callee"); + + return InstPair(nullptr, nullptr); + } - IRInst* diffCallee = nullptr; + // Since concrete functions are globals, the primal callee is the same + // as the original callee. + // + auto primalCallee = origCallee; - if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) - { - // If the user has already provided an differentiated implementation, use that. - diffCallee = derivativeReferenceDecor->getJVPFunc(); - } - else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + IRInst* diffCallee = nullptr; + + if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + // If the user has already provided an differentiated implementation, use that. + diffCallee = derivativeReferenceDecor->getJVPFunc(); + } + else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>()) + { + // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass + // to generate the implementation. + diffCallee = builder->emitJVPDifferentiateInst( + differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), + primalCallee); + } + else + { + // The callee is non differentiable, just return primal value with null diff value. + IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); + return InstPair(primalCall, nullptr); + } + + List<IRInst*> args; + // Go over the parameter list and create pairs for each input (if required) + for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) + { + auto origArg = origCall->getArg(ii); + auto primalArg = findOrTranscribePrimalInst(builder, origArg); + SLANG_ASSERT(primalArg); + + auto primalType = primalArg->getDataType(); + auto diffArg = findOrTranscribeDiffInst(builder, origArg); + + if (!diffArg) + diffArg = getDifferentialZeroOfType(builder, primalType); + + if (auto pairType = tryGetDiffPairType(builder, primalType)) { - // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass - // to generate the implementation. - diffCallee = builder->emitJVPDifferentiateInst( - differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())), - primalCallee); + // If a pair type can be formed, this must be non-null. + SLANG_RELEASE_ASSERT(diffArg); + auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); + args.add(diffPair); } else { - // The callee is non differentiable, just return primal value with null diff value. - IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall); - return InstPair(primalCall, nullptr); + // Add original/primal argument. + args.add(primalArg); } + } + + IRType* diffReturnType = nullptr; + diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); - List<IRInst*> args; - // Go over the parameter list and create pairs for each input (if required) - for (UIndex ii = 0; ii < origCall->getArgCount(); ii++) - { - auto origArg = origCall->getArg(ii); - auto primalArg = findOrTranscribePrimalInst(builder, origArg); - SLANG_ASSERT(primalArg); - - auto primalType = primalArg->getDataType(); - auto diffArg = findOrTranscribeDiffInst(builder, origArg); - - if (!diffArg) - diffArg = getDifferentialZeroOfType(builder, primalType); - - if (auto pairType = tryGetDiffPairType(builder, primalType)) - { - // If a pair type can be formed, this must be non-null. - SLANG_RELEASE_ASSERT(diffArg); - auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg); - args.add(diffPair); - } - else - { - // Add original/primal argument. - args.add(primalArg); - } - } - - IRType* diffReturnType = nullptr; - diffReturnType = tryGetDiffPairType(builder, origCall->getFullType()); - SLANG_ASSERT(diffReturnType); + if (!diffReturnType) + { + SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType); + diffReturnType = builder->getVoidType(); + } - auto callInst = builder->emitCallInst( - diffReturnType, - diffCallee, - args); + auto callInst = builder->emitCallInst( + diffReturnType, + diffCallee, + args); + if (diffReturnType->getOp() != kIROp_VoidType) + { IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst); IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst); - return InstPair(primalResultValue, diffResultValue); } - else if(as<IRSpecialize>(origCall->getCallee()) || - as<IRLookupWitnessMethod>(origCall->getCallee())) - { - getSink()->diagnose(origCall->sourceLoc, - Diagnostics::unimplemented, - "attempting to differentiate unspecialized callee or an interface method"); - } else { - // Note that this can only happen if the callee is a result - // of a higher-order operation. For now, we assume that we cannot - // differentiate such calls safely. - // TODO(sai): Should probably get checked in the front-end. - // - getSink()->diagnose(origCall->sourceLoc, - Diagnostics::internalCompilerError, - "attempting to differentiate unresolved callee"); + // Return the inst itself if the return value is void. + // This is fine since these values should never actually be used anywhere. + // + return InstPair(callInst, callInst); } - - return InstPair(nullptr, nullptr); } InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle) @@ -1314,17 +1320,35 @@ struct JVPTranscriber return InstPair(nullptr, nullptr); } - InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize) + InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize) { - // This is slightly counter-intuitive, but we don't perform any differentiation - // logic here. We simple clone the original specialize which points to the original function, - // or the cloned version in case we're inside a generic scope. - // The differentiation logic is inserted later when this is used in an IRCall. - // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn)) - // rather than have Specialize(JVPDifferentiate(Fn)) - // - auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize); - return InstPair(diffSpecialize, diffSpecialize); + // In general, we should not see any specialize insts at this stage. + // The exceptions are target intrinsics. + auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase())); + if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>()) + { + // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst. + // (Normally, this would be on the inner IRFunc, but in this case only the JVP func + // can be specialized, so we put a decoration on the IRSpecialize) + // + if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + auto jvpFunc = jvpFuncDecoration->getJVPFunc(); + + // Make sure this isn't itself a specialize . + SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc)); + + return InstPair(jvpFunc, jvpFunc); + } + } + else + { + getSink()->diagnose(origSpecialize->sourceLoc, + Diagnostics::unexpected, + "should not be attempting to differentiate anything specialized here."); + } + + return InstPair(nullptr, nullptr); } InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup) @@ -1777,10 +1801,7 @@ struct JVPTranscriber return transcribeConst(builder, origInst); case kIROp_Specialize: - getSink()->diagnose(origInst->sourceLoc, - Diagnostics::unexpected, - "should not be attempting to differentiate anything specialized here."); - return InstPair(nullptr, nullptr); + return transcribeSpecialize(builder, as<IRSpecialize>(origInst)); case kIROp_lookup_interface_method: return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst)); @@ -1972,7 +1993,10 @@ struct JVPDerivativeContext if (auto specializeInst = as<IRSpecialize>(baseInst)) { - baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase()); + // Certain specialize insts come with a derivative + // reference attached. Skip such instructions. + // + if (lookupJVPReference(specializeInst)) continue; } else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst)) { @@ -2026,6 +2050,12 @@ struct JVPDerivativeContext { builder->setInsertBefore(pairType); + if (!as<IRType>(pairType->getValueType())) + { + // Do not handle non-concrete types. + return nullptr; + } + auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType( builder, pairType->getValueType()); @@ -2054,19 +2084,20 @@ struct JVPDerivativeContext if (auto makePairInst = as<IRMakeDifferentialPair>(inst)) { - auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext); - - builder->setInsertBefore(makePairInst); - - List<IRInst*> operands; - operands.add(makePairInst->getPrimalValue()); - operands.add(makePairInst->getDifferentialValue()); + if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext)) + { + builder->setInsertBefore(makePairInst); + + List<IRInst*> operands; + operands.add(makePairInst->getPrimalValue()); + operands.add(makePairInst->getDifferentialValue()); - auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); - makePairInst->replaceUsesWith(makeStructInst); - makePairInst->removeAndDeallocate(); + auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands); + makePairInst->replaceUsesWith(makeStructInst); + makePairInst->removeAndDeallocate(); - return makeStructInst; + return makeStructInst; + } } return nullptr; @@ -2074,30 +2105,30 @@ struct JVPDerivativeContext IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext) { - if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst)) { - lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext); - - builder->setInsertBefore(getDiffInst); - - auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); - getDiffInst->replaceUsesWith(diffFieldExtract); - getDiffInst->removeAndDeallocate(); - - return diffFieldExtract; + if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext)) + { + builder->setInsertBefore(getDiffInst); + + auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase()); + getDiffInst->replaceUsesWith(diffFieldExtract); + getDiffInst->removeAndDeallocate(); + return diffFieldExtract; + } } else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst)) { - lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext); - - builder->setInsertBefore(getPrimalInst); + if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext)) + { + builder->setInsertBefore(getPrimalInst); - auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); - getPrimalInst->replaceUsesWith(primalFieldExtract); - getPrimalInst->removeAndDeallocate(); + auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase()); + getPrimalInst->replaceUsesWith(primalFieldExtract); + getPrimalInst->removeAndDeallocate(); - return primalFieldExtract; + return primalFieldExtract; + } } return nullptr; @@ -2295,9 +2326,9 @@ bool processJVPDerivativeMarkers( return changed; } -void stripAutoDiffDecorations(IRModule* module) +void stripAutoDiffDecorationsFromChildren(IRInst* parent) { - for (auto inst : module->getGlobalInsts()) + for (auto inst : parent->getChildren()) { for (auto decor = inst->getFirstDecoration(); decor; ) { @@ -2313,7 +2344,18 @@ void stripAutoDiffDecorations(IRModule* module) } decor = next; } + + if (inst->getFirstChild() != nullptr) + { + stripAutoDiffDecorationsFromChildren(inst); + } } } +void stripAutoDiffDecorations(IRModule* module) +{ + stripAutoDiffDecorationsFromChildren(module->getModuleInst()); +} + + } diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 3ef79df28..53ea99a0c 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -373,7 +373,76 @@ struct SpecializationContext // represents a definition rather than a declaration. // if(!canSpecializeGeneric(genericVal)) + { + // We have to consider a special case here if baseVal is + // an intrinsic, and contains a custom differential. + // This is a case where the base cannot be specialized since it has + // no body, but the custom should be specialized. + // A better way to handle this would be to grab a reference to the + // appropriate custom differential, if one exists, at checking time + // during CheckInvoke() and construct it's specialization args appropriately. + // + // For now, we will overwrite the specialization args for the differential + // using the args for the base. + // + auto genericReturnVal = findInnerMostGenericReturnVal(genericVal); + if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>()) + { + if (auto customDiffRef = genericReturnVal->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + // If we already have a diff func on this specialize, skip. + if (auto specDiffRef = specInst->findDecoration<IRJVPDerivativeReferenceDecoration>()) + { + return false; + } + + auto specDiffFunc = as<IRSpecialize>(customDiffRef->getJVPFunc()); + + // If the base is specialized, the JVP version must be also be a specialized + // generic. + // + SLANG_ASSERT(specDiffFunc); + + // Build specialization arguments from specInst. + // Note that if we've reached this point, we can safely assume + // that our args are fully specialized/concrete. + // + UCount argCount = specInst->getArgCount(); + List<IRInst*> args; + for (UIndex ii = 0; ii < argCount; ii++) + args.add(specInst->getArg(ii)); + + IRBuilder builder(&sharedBuilderStorage); + + // Specialize the custom JVP function type with the original arguments. + builder.setInsertInto(module); + auto newDiffFuncType = builder.emitSpecializeInst( + builder.getTypeKind(), + specDiffFunc->getBase()->getDataType(), + argCount, + args.getBuffer()); + + // Specialize the custom JVP function with the original arguments. + builder.setInsertBefore(specInst); + auto newDiffFunc = builder.emitSpecializeInst( + (IRType*) newDiffFuncType, + specDiffFunc->getBase(), + argCount, + args.getBuffer()); + + // Add the new spec insts to the list so they get specialized with + // the usual logic. + // + addToWorkList(newDiffFuncType); + addToWorkList(newDiffFunc); + + builder.addJVPDerivativeReferenceDecoration(specInst, newDiffFunc); + + return true; + } + } return false; + } // Once we know that specialization is possible, // the actual work is fairly simple. diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 1e58a456e..ff66caa90 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -8204,12 +8204,28 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration); } + // Register the value now, to avoid any possible infinite recursion when lowering CustomJVPAttribute + setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc))); + if (auto attr = decl->findModifier<CustomJVPAttribute>()) { - auto loweredVal = lowerLValueExpr(this->context, attr->funcDeclRef); + // TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly. + // If we don't move the cursor to the parent, we sometimes emit supporting + // insts into the function body, which shouldn't happen. + // + subContext->irBuilder->setInsertInto(irFunc->getParent()); + + auto diffFuncType = getFuncType(subContext->astBuilder, attr->funcDeclRef->declRef.as<CallableDecl>()); + auto irDiffFuncType = lowerType(subContext, diffFuncType); + + auto loweredVal = emitDeclRef(subContext, attr->funcDeclRef->declRef, irDiffFuncType); + SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRFunc* jvpFunc = as<IRFunc>(loweredVal.val); + IRInst* jvpFunc = loweredVal.val; getBuilder()->addDecoration(irFunc, kIROp_JVPDerivativeReferenceDecoration, jvpFunc); + + // Reset cursor. + subContext->irBuilder->setInsertInto(irFunc); } // For convenience, ensure that any additional global diff --git a/tests/autodiff/custom-intrinsic.slang b/tests/autodiff/custom-intrinsic.slang new file mode 100644 index 000000000..8ce354edc --- /dev/null +++ b/tests/autodiff/custom-intrinsic.slang @@ -0,0 +1,119 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef __DifferentialPair<float> dpfloat; + +typealias IDFloat = IFloat & IDifferentiable; + +namespace myintrinsiclib +{ + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_exp($0)") + __target_intrinsic(cpp, "$P_exp($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") + [__custom_jvp(d_exp<T>)] + T exp(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + exp(dpx.p()), + T.dmul(exp(dpx.p()), dpx.d())); + } + + + // Sine + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_sin($0)") + __target_intrinsic(cpp, "$P_sin($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") + [__custom_jvp(d_sin<T>)] + T sin(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + sin(dpx.p()), + T.dmul(cos(dpx.p()), dpx.d())); + } + + // Cosine + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(glsl) + __target_intrinsic(cuda, "$P_cos($0)") + __target_intrinsic(cpp, "$P_cos($0)") + __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") + [__custom_jvp(d_cos<T>)] + T cos(T x); + + __generic<T : IDFloat> + __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx) + { + return __DifferentialPair<T>( + cos(dpx.p()), + T.dmul(-sin(dpx.p()), dpx.d())); + } + + // Sine and cosine + __generic<T : IDFloat> + __target_intrinsic(hlsl) + __target_intrinsic(cuda, "$P_sincos($0, $1, $2)") + [__custom_jvp(d_sincos<T>)] + void sincos(T x, out T s, out T c) + { + s = sin(x); + c = cos(x); + } + + __generic<T : IDFloat> + void d_sincos(__DifferentialPair<T> x, out __DifferentialPair<T> s, out __DifferentialPair<T> c) + { + T _s; + T _c; + sincos(x.p(), _s, _c); + + s = __DifferentialPair<T>(_s, T.dmul(_c, x.d())); + c = __DifferentialPair<T>(_c, T.dmul(-_s, x.d())); + } +}; + +__differentiate_jvp float f(float x) +{ + return myintrinsiclib.exp(x); +} + +__differentiate_jvp float g(float x) +{ + float s; + float t; + myintrinsiclib.sincos(x, s, t); + + return s + t; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + + outputBuffer[0] = f(dpa.p()); // Expect: 7.389056 + outputBuffer[1] = __jvp(f)(dpa).d(); // Expect: 7.389056 + + // g() needs additional handling of IRMakeDifferentialPair(PtrType). This needs to + // generate a new var, load from the individual vars and store into the pair var. + + //outputBuffer[2] = g(dpa.p()); // Expect: 1.381773 + //outputBuffer[3] = __jvp(g)(dpa).d(); // Expect: -0.301168 + } +}
\ No newline at end of file diff --git a/tests/autodiff/custom-intrinsic.slang.expected.txt b/tests/autodiff/custom-intrinsic.slang.expected.txt new file mode 100644 index 000000000..ce22a5b95 --- /dev/null +++ b/tests/autodiff/custom-intrinsic.slang.expected.txt @@ -0,0 +1,6 @@ +type: float +7.389056 +7.389056 +0.0 +0.0 +0.0
\ No newline at end of file diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang new file mode 100644 index 000000000..6c7ecffbe --- /dev/null +++ b/tests/autodiff/dstdlib.slang @@ -0,0 +1,37 @@ +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +typedef __DifferentialPair<float> dpfloat; + +__differentiate_jvp float f(float x) +{ + return dstd.exp(x); +} + +__differentiate_jvp float g(float x) +{ + return dstd.sin(x); +} + +__differentiate_jvp float h(float x) +{ + return dstd.cos(x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID) +{ + { + dpfloat dpa = dpfloat(2.0, 1.0); + + outputBuffer[0] = f(dpa.p()); // Expect: 7.389056 + outputBuffer[1] = __jvp(f)(dpa).d(); // Expect: 7.389056 + outputBuffer[2] = g(dpa.p()); // Expect: 0.909297 + outputBuffer[3] = __jvp(g)(dpa).d(); // Expect: -0.416146 + outputBuffer[4] = h(dpa.p()); // Expect: -0.416146 + outputBuffer[5] = __jvp(h)(dpa).d(); // Expect: -0.909297 + } +}
\ No newline at end of file diff --git a/tests/autodiff/dstdlib.slang.expected.txt b/tests/autodiff/dstdlib.slang.expected.txt new file mode 100644 index 000000000..82053b379 --- /dev/null +++ b/tests/autodiff/dstdlib.slang.expected.txt @@ -0,0 +1,7 @@ +type: float +7.389056 +7.389056 +0.909297 +-0.416147 +-0.416147 +-0.909297
\ No newline at end of file |
