summaryrefslogtreecommitdiffstats
path: root/source/slang
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-26 22:21:29 -0400
committerGitHub <noreply@github.com>2022-10-26 19:21:29 -0700
commitf7f0dcadd3b2aca4c0bcd03a96e11c617cf69fc2 (patch)
tree574dff2bcb8c5a3de9e74d18346a424c82d62a7a /source/slang
parent939be44ca23476e622dfb24a592383fe2a1da61f (diff)
Adding a differentiable standard library (#2465)
Diffstat (limited to 'source/slang')
-rw-r--r--source/slang/diff.meta.slang63
-rw-r--r--source/slang/slang-ast-expr.h2
-rw-r--r--source/slang/slang-check-conformance.cpp1
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-modifier.cpp97
-rw-r--r--source/slang/slang-diagnostic-defs.h4
-rw-r--r--source/slang/slang-ir-diff-call.cpp35
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp282
-rw-r--r--source/slang/slang-ir-specialize.cpp69
-rw-r--r--source/slang/slang-lower-to-ir.cpp20
10 files changed, 422 insertions, 153 deletions
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index f314e0487..ea204c839 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -5,7 +5,7 @@
syntax __differentiate_jvp : JVPDerivativeModifier;
// Custom JVP Function reference
-__attributeTarget(FuncDecl)
+__attributeTarget(FunctionDeclBase)
attribute_syntax [__custom_jvp(function)] : CustomJVPAttribute;
/// Interface to denote types as differentiable.
@@ -39,7 +39,7 @@ extension float : IDifferentiable
[__unsafeForceInlineEarly]
static Differential zero()
{
- return 0.f;
+ return float(0.f);
}
[__unsafeForceInlineEarly]
@@ -151,3 +151,62 @@ struct __DifferentialPair
return p();
}
};
+
+typealias IDFloat = IFloat & IDifferentiable;
+
+namespace dstd
+{
+ // Natural Exponent
+ __generic<T : IDFloat>
+ __target_intrinsic(hlsl)
+ __target_intrinsic(glsl)
+ __target_intrinsic(cuda, "$P_exp($0)")
+ __target_intrinsic(cpp, "$P_exp($0)")
+ __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
+ [__custom_jvp(d_exp<T>)]
+ T exp(T x);
+
+ __generic<T : IDFloat>
+ __DifferentialPair<T> d_exp(__DifferentialPair<T> dpx)
+ {
+ return __DifferentialPair<T>(
+ exp(dpx.p()),
+ T.dmul(exp(dpx.p()), dpx.d()));
+ }
+
+ // Sine
+ __generic<T : IDFloat>
+ __target_intrinsic(hlsl)
+ __target_intrinsic(glsl)
+ __target_intrinsic(cuda, "$P_sin($0)")
+ __target_intrinsic(cpp, "$P_sin($0)")
+ __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
+ [__custom_jvp(d_sin<T>)]
+ T sin(T x);
+
+ __generic<T : IDFloat>
+ __DifferentialPair<T> d_sin(__DifferentialPair<T> dpx)
+ {
+ return __DifferentialPair<T>(
+ sin(dpx.p()),
+ T.dmul(cos(dpx.p()), dpx.d()));
+ }
+
+ // Cosine
+ __generic<T : IDFloat>
+ __target_intrinsic(hlsl)
+ __target_intrinsic(glsl)
+ __target_intrinsic(cuda, "$P_cos($0)")
+ __target_intrinsic(cpp, "$P_cos($0)")
+ __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
+ [__custom_jvp(d_cos<T>)]
+ T cos(T x);
+
+ __generic<T : IDFloat>
+ __DifferentialPair<T> d_cos(__DifferentialPair<T> dpx)
+ {
+ return __DifferentialPair<T>(
+ cos(dpx.p()),
+ T.dmul(-sin(dpx.p()), dpx.d()));
+ }
+};
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index baa6de73a..fca628a49 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -20,8 +20,8 @@ class DeclRefExpr: public Expr
// The declaration of the symbol being referenced
- DeclRef<Decl> declRef;
+ DeclRef<Decl> declRef;
// The name of the symbol being referenced
Name* name = nullptr;
diff --git a/source/slang/slang-check-conformance.cpp b/source/slang/slang-check-conformance.cpp
index cf362dcdd..2c9977082 100644
--- a/source/slang/slang-check-conformance.cpp
+++ b/source/slang/slang-check-conformance.cpp
@@ -353,6 +353,7 @@ namespace Slang
for( auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(genericDeclRef) )
{
+ ensureDecl(constraintDeclRef, DeclCheckState::CanUseBaseOfInheritanceDecl);
auto sub = getSub(m_astBuilder, constraintDeclRef);
auto sup = getSup(m_astBuilder, constraintDeclRef);
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index ac1d624c2..2dc08262e 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -1831,8 +1831,8 @@ namespace Slang
Expr* CheckExpr(Expr* expr);
- Expr* CheckInvokeExprWithCheckedOperands(InvokeExpr *expr);
+ Expr* CheckInvokeExprWithCheckedOperands(InvokeExpr *expr);
// Get the type to use when referencing a declaration
QualType GetTypeForDeclRef(DeclRef<Decl> declRef, SourceLoc loc);
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index 7e11ee3ca..20e5d5378 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -628,13 +628,102 @@ namespace Slang
else if (auto customJVPAttr = as<CustomJVPAttribute>(attr))
{
SLANG_ASSERT(attr->args.getCount() == 1);
-
+ SLANG_ASSERT(as<Decl>(attrTarget));
+
// Ensure that the argument is a reference to a function definition or declaration.
- auto funcExpr = as<DeclRefExpr>(CheckTerm(attr->args[0]));
- if (!as<FuncType>(funcExpr->type))
+ auto diffExpr = CheckTerm(attr->args[0]);
+ if (diffExpr->type == getASTBuilder()->getErrorType())
+ {
+ // Could not resolve the term.
+ getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
+ return false;
+ }
+
+ // Either diffExpr has a function type, or it is a reference to a generic.
+ if (!as<FuncType>(diffExpr->type) &&
+ !(as<DeclRefExpr>(diffExpr) &&
+ as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr))
+ {
return false;
+ }
- customJVPAttr->funcDeclRef = funcExpr;
+ auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef;
+
+ UCount genericLevels = 0;
+ // If we've grabbed the outer generic for some reason,
+ // recursively construct GenericAppExpr<...>(generic)
+ // and check that to get a specialized func.
+ //
+ while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr)
+ {
+ // Forward to the inner decl
+ diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner);
+
+ // Increment counter.
+ genericLevels += 1;
+ }
+
+ auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl);
+ auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl);
+ Expr* currentDiffExpr = diffExpr;
+
+ // Go back through each level, and use generic declarations in the
+ // target's generic scope as arguments for the diff function's generic.
+ //
+ for (UIndex ii = 0; ii < genericLevels; ii++)
+ {
+ // Nest our expression inside a GenericAppExpr
+ auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>();
+ genericAppExpr->functionExpr = currentDiffExpr;
+
+ // Construct references to the generic args in the current scope.
+ // TODO: Probably an easier way to do this.
+ for (auto member : targetGeneric->members)
+ {
+ if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
+ {
+ genericAppExpr->arguments.add(
+ ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr));
+ }
+ else if (auto valueParamDecl = as<GenericValueParamDecl>(member))
+ {
+ genericAppExpr->arguments.add(
+ ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr));
+ }
+ }
+
+ // Set our generic-app-expr as the new expr.
+ currentDiffExpr = genericAppExpr;
+
+ // Peel the generic layer.
+ diffGeneric = as<GenericDecl>(diffGeneric->parentDecl);
+ targetGeneric = as<GenericDecl>(targetGeneric->parentDecl);
+ }
+
+ if ((diffGeneric == nullptr && targetGeneric != nullptr) ||
+ (targetGeneric == nullptr && diffGeneric != nullptr))
+ {
+ //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget);
+ SLANG_UNEXPECTED("");
+ }
+
+ // If we had to change currentDiffExpr, then re-check the expr.
+ if (!currentDiffExpr->type)
+ {
+ currentDiffExpr = CheckTerm(currentDiffExpr);
+ }
+
+ // Ensure that the argument is a reference to a function definition or declaration.
+ auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr);
+ auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef;
+
+ if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc())))
+ {
+ getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef);
+ }
+
+ // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr)
+ customJVPAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 6e6a6f5e5..9e939e476 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -338,6 +338,10 @@ DIAGNOSTIC(31141, Error, definitionOfExternDeclMismatchesOriginalDefinition, "`e
DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl '$0' has ambiguous original definitions.")
DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.")
+DIAGNOSTIC(31144, Error, customDerivativeNotAFunction, "$0, used as a custom derivative, is not a function")
+DIAGNOSTIC(31145, Error, customDerivativeGenericSignatureMismatch, "cannot use $0 as custom derivative for $1. generic signature does not match")
+DIAGNOSTIC(31146, Error, customDerivativeSignatureMismatch, "cannot use $0 as custom derivative for $1. signature does not match")
+DIAGNOSTIC(31146, Error, invalidCustomDerivative, "unable to resolve custom differential for $0.")
// Enums
DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'")
diff --git a/source/slang/slang-ir-diff-call.cpp b/source/slang/slang-ir-diff-call.cpp
index ee78246fe..3f2c6c789 100644
--- a/source/slang/slang-ir-diff-call.cpp
+++ b/source/slang/slang-ir-diff-call.cpp
@@ -57,22 +57,22 @@ struct DerivativeCallProcessContext
// First get base function
auto origCallable = derivOfInst->getBaseFn();
- IRSpecialize* specialization = nullptr;
-
- // If the base is a specialize inst, get the inner fn.
+ // Resolve the derivative function for IRJVPDifferentiate(IRSpecialize(IRFunc))
+ // Check the specialize inst for a reference to the derivative fn.
+ //
if (auto origSpecialize = as<IRSpecialize>(origCallable))
{
- specialization = origSpecialize;
- origCallable = origSpecialize->getBase();
+ if (auto jvpSpecRefDecorator = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ jvpCallable = jvpSpecRefDecorator->getJVPFunc();
+ }
}
- // We should have either a generic or a function reference on our hands.
- SLANG_ASSERT(as<IRGeneric>(origCallable) || as<IRFunc>(origCallable));
-
- // Resolve the derivative function.
+ // Resolve the derivative function for an IRJVPDifferentiate(IRFunc)
//
// Check for the 'JVPDerivativeReference' decorator on the
// base function.
+ //
if (auto jvpRefDecorator = origCallable->findDecoration<IRJVPDerivativeReferenceDecoration>())
{
jvpCallable = jvpRefDecorator->getJVPFunc();
@@ -80,20 +80,9 @@ struct DerivativeCallProcessContext
SLANG_ASSERT(jvpCallable);
- if (specialization)
- {
- // Replace the specialization target with the JVP func.
- specialization->setOperand(0, jvpCallable);
-
- // Then replace the JVPDifferentiate inst with the specialization.
- derivOfInst->replaceUsesWith(specialization);
- }
- else
- {
- // Substitute all uses of the 'derivativeOf' operation
- // with the resolved derivative function.
- derivOfInst->replaceUsesWith(jvpCallable);
- }
+ // Substitute all uses of the 'derivativeOf' operation
+ // with the resolved derivative function.
+ derivOfInst->replaceUsesWith(jvpCallable);
// Remove the 'derivativeOf' inst.
derivOfInst->removeAndDeallocate();
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index b97556ab1..1a86506b3 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -1091,100 +1091,106 @@ struct JVPTranscriber
// in the current transcription context.
//
InstPair transcribeCall(IRBuilder* builder, IRCall* origCall)
- {
- if (as<IRFunc>(origCall->getCallee()))
- {
- auto origCallee = origCall->getCallee();
+ {
+
+ IRInst* origCallee = origCall->getCallee();
- // Since concrete functions are globals, the primal callee is the same
- // as the original callee.
+ if (!origCallee)
+ {
+ // Note that this can only happen if the callee is a result
+ // of a higher-order operation. For now, we assume that we cannot
+ // differentiate such calls safely.
+ // TODO(sai): Should probably get checked in the front-end.
//
- auto primalCallee = origCallee;
+ getSink()->diagnose(origCall->sourceLoc,
+ Diagnostics::internalCompilerError,
+ "attempting to differentiate unresolved callee");
+
+ return InstPair(nullptr, nullptr);
+ }
- IRInst* diffCallee = nullptr;
+ // Since concrete functions are globals, the primal callee is the same
+ // as the original callee.
+ //
+ auto primalCallee = origCallee;
- if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
- {
- // If the user has already provided an differentiated implementation, use that.
- diffCallee = derivativeReferenceDecor->getJVPFunc();
- }
- else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ IRInst* diffCallee = nullptr;
+
+ if (auto derivativeReferenceDecor = primalCallee->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ // If the user has already provided an differentiated implementation, use that.
+ diffCallee = derivativeReferenceDecor->getJVPFunc();
+ }
+ else if (primalCallee->findDecoration<IRJVPDerivativeMarkerDecoration>())
+ {
+ // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
+ // to generate the implementation.
+ diffCallee = builder->emitJVPDifferentiateInst(
+ differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
+ primalCallee);
+ }
+ else
+ {
+ // The callee is non differentiable, just return primal value with null diff value.
+ IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
+ return InstPair(primalCall, nullptr);
+ }
+
+ List<IRInst*> args;
+ // Go over the parameter list and create pairs for each input (if required)
+ for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
+ {
+ auto origArg = origCall->getArg(ii);
+ auto primalArg = findOrTranscribePrimalInst(builder, origArg);
+ SLANG_ASSERT(primalArg);
+
+ auto primalType = primalArg->getDataType();
+ auto diffArg = findOrTranscribeDiffInst(builder, origArg);
+
+ if (!diffArg)
+ diffArg = getDifferentialZeroOfType(builder, primalType);
+
+ if (auto pairType = tryGetDiffPairType(builder, primalType))
{
- // If the function is marked for auto-diff, push a `differentiate` inst for a follow up pass
- // to generate the implementation.
- diffCallee = builder->emitJVPDifferentiateInst(
- differentiateFunctionType(builder, as<IRFuncType>(primalCallee->getFullType())),
- primalCallee);
+ // If a pair type can be formed, this must be non-null.
+ SLANG_RELEASE_ASSERT(diffArg);
+ auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
+ args.add(diffPair);
}
else
{
- // The callee is non differentiable, just return primal value with null diff value.
- IRInst* primalCall = cloneInst(&cloneEnv, builder, origCall);
- return InstPair(primalCall, nullptr);
+ // Add original/primal argument.
+ args.add(primalArg);
}
+ }
+
+ IRType* diffReturnType = nullptr;
+ diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
- List<IRInst*> args;
- // Go over the parameter list and create pairs for each input (if required)
- for (UIndex ii = 0; ii < origCall->getArgCount(); ii++)
- {
- auto origArg = origCall->getArg(ii);
- auto primalArg = findOrTranscribePrimalInst(builder, origArg);
- SLANG_ASSERT(primalArg);
-
- auto primalType = primalArg->getDataType();
- auto diffArg = findOrTranscribeDiffInst(builder, origArg);
-
- if (!diffArg)
- diffArg = getDifferentialZeroOfType(builder, primalType);
-
- if (auto pairType = tryGetDiffPairType(builder, primalType))
- {
- // If a pair type can be formed, this must be non-null.
- SLANG_RELEASE_ASSERT(diffArg);
- auto diffPair = builder->emitMakeDifferentialPair(pairType, primalArg, diffArg);
- args.add(diffPair);
- }
- else
- {
- // Add original/primal argument.
- args.add(primalArg);
- }
- }
-
- IRType* diffReturnType = nullptr;
- diffReturnType = tryGetDiffPairType(builder, origCall->getFullType());
- SLANG_ASSERT(diffReturnType);
+ if (!diffReturnType)
+ {
+ SLANG_RELEASE_ASSERT(origCall->getFullType()->getOp() == kIROp_VoidType);
+ diffReturnType = builder->getVoidType();
+ }
- auto callInst = builder->emitCallInst(
- diffReturnType,
- diffCallee,
- args);
+ auto callInst = builder->emitCallInst(
+ diffReturnType,
+ diffCallee,
+ args);
+ if (diffReturnType->getOp() != kIROp_VoidType)
+ {
IRInst* primalResultValue = pairBuilder->emitPrimalFieldAccess(builder, callInst);
IRInst* diffResultValue = pairBuilder->emitDiffFieldAccess(builder, callInst);
-
return InstPair(primalResultValue, diffResultValue);
}
- else if(as<IRSpecialize>(origCall->getCallee()) ||
- as<IRLookupWitnessMethod>(origCall->getCallee()))
- {
- getSink()->diagnose(origCall->sourceLoc,
- Diagnostics::unimplemented,
- "attempting to differentiate unspecialized callee or an interface method");
- }
else
{
- // Note that this can only happen if the callee is a result
- // of a higher-order operation. For now, we assume that we cannot
- // differentiate such calls safely.
- // TODO(sai): Should probably get checked in the front-end.
- //
- getSink()->diagnose(origCall->sourceLoc,
- Diagnostics::internalCompilerError,
- "attempting to differentiate unresolved callee");
+ // Return the inst itself if the return value is void.
+ // This is fine since these values should never actually be used anywhere.
+ //
+ return InstPair(callInst, callInst);
}
-
- return InstPair(nullptr, nullptr);
}
InstPair transcribeSwizzle(IRBuilder* builder, IRSwizzle* origSwizzle)
@@ -1314,17 +1320,35 @@ struct JVPTranscriber
return InstPair(nullptr, nullptr);
}
- InstPair transcribeSpecialize(IRBuilder* builder, IRSpecialize* origSpecialize)
+ InstPair transcribeSpecialize(IRBuilder*, IRSpecialize* origSpecialize)
{
- // This is slightly counter-intuitive, but we don't perform any differentiation
- // logic here. We simple clone the original specialize which points to the original function,
- // or the cloned version in case we're inside a generic scope.
- // The differentiation logic is inserted later when this is used in an IRCall.
- // This decision is mostly to maintain a uniform convention of JVPDifferentiate(Specialize(Fn))
- // rather than have Specialize(JVPDifferentiate(Fn))
- //
- auto diffSpecialize = cloneInst(&cloneEnv, builder, origSpecialize);
- return InstPair(diffSpecialize, diffSpecialize);
+ // In general, we should not see any specialize insts at this stage.
+ // The exceptions are target intrinsics.
+ auto genericInnerVal = findInnerMostGenericReturnVal(as<IRGeneric>(origSpecialize->getBase()));
+ if (genericInnerVal->findDecoration<IRTargetIntrinsicDecoration>())
+ {
+ // Look for an IRJVPDerivativeReferenceDecoration on the specialize inst.
+ // (Normally, this would be on the inner IRFunc, but in this case only the JVP func
+ // can be specialized, so we put a decoration on the IRSpecialize)
+ //
+ if (auto jvpFuncDecoration = origSpecialize->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ auto jvpFunc = jvpFuncDecoration->getJVPFunc();
+
+ // Make sure this isn't itself a specialize .
+ SLANG_RELEASE_ASSERT(!as<IRSpecialize>(jvpFunc));
+
+ return InstPair(jvpFunc, jvpFunc);
+ }
+ }
+ else
+ {
+ getSink()->diagnose(origSpecialize->sourceLoc,
+ Diagnostics::unexpected,
+ "should not be attempting to differentiate anything specialized here.");
+ }
+
+ return InstPair(nullptr, nullptr);
}
InstPair transcibeLookupInterfaceMethod(IRBuilder* builder, IRLookupWitnessMethod* origLookup)
@@ -1777,10 +1801,7 @@ struct JVPTranscriber
return transcribeConst(builder, origInst);
case kIROp_Specialize:
- getSink()->diagnose(origInst->sourceLoc,
- Diagnostics::unexpected,
- "should not be attempting to differentiate anything specialized here.");
- return InstPair(nullptr, nullptr);
+ return transcribeSpecialize(builder, as<IRSpecialize>(origInst));
case kIROp_lookup_interface_method:
return transcibeLookupInterfaceMethod(builder, as<IRLookupWitnessMethod>(origInst));
@@ -1972,7 +1993,10 @@ struct JVPDerivativeContext
if (auto specializeInst = as<IRSpecialize>(baseInst))
{
- baseFunction = as<IRGlobalValueWithCode>(specializeInst->getBase());
+ // Certain specialize insts come with a derivative
+ // reference attached. Skip such instructions.
+ //
+ if (lookupJVPReference(specializeInst)) continue;
}
else if (auto globalValWithCode = as<IRGlobalValueWithCode>(baseInst))
{
@@ -2026,6 +2050,12 @@ struct JVPDerivativeContext
{
builder->setInsertBefore(pairType);
+ if (!as<IRType>(pairType->getValueType()))
+ {
+ // Do not handle non-concrete types.
+ return nullptr;
+ }
+
auto diffPairStructType = (&pairBuilderStorage)->getOrCreateDiffPairType(
builder,
pairType->getValueType());
@@ -2054,19 +2084,20 @@ struct JVPDerivativeContext
if (auto makePairInst = as<IRMakeDifferentialPair>(inst))
{
- auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext);
-
- builder->setInsertBefore(makePairInst);
-
- List<IRInst*> operands;
- operands.add(makePairInst->getPrimalValue());
- operands.add(makePairInst->getDifferentialValue());
+ if (auto diffPairStructType = lowerPairType(builder, makePairInst->getDataType(), diffContext))
+ {
+ builder->setInsertBefore(makePairInst);
+
+ List<IRInst*> operands;
+ operands.add(makePairInst->getPrimalValue());
+ operands.add(makePairInst->getDifferentialValue());
- auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
- makePairInst->replaceUsesWith(makeStructInst);
- makePairInst->removeAndDeallocate();
+ auto makeStructInst = builder->emitMakeStruct((IRType*)(diffPairStructType), operands);
+ makePairInst->replaceUsesWith(makeStructInst);
+ makePairInst->removeAndDeallocate();
- return makeStructInst;
+ return makeStructInst;
+ }
}
return nullptr;
@@ -2074,30 +2105,30 @@ struct JVPDerivativeContext
IRInst* lowerPairAccess(IRBuilder* builder, IRInst* inst, DifferentiableTypeConformanceContext* diffContext)
{
-
if (auto getDiffInst = as<IRDifferentialPairGetDifferential>(inst))
{
- lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext);
-
- builder->setInsertBefore(getDiffInst);
-
- auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
- getDiffInst->replaceUsesWith(diffFieldExtract);
- getDiffInst->removeAndDeallocate();
-
- return diffFieldExtract;
+ if (lowerPairType(builder, getDiffInst->getBase()->getDataType(), diffContext))
+ {
+ builder->setInsertBefore(getDiffInst);
+
+ auto diffFieldExtract = (&pairBuilderStorage)->emitDiffFieldAccess(builder, getDiffInst->getBase());
+ getDiffInst->replaceUsesWith(diffFieldExtract);
+ getDiffInst->removeAndDeallocate();
+ return diffFieldExtract;
+ }
}
else if (auto getPrimalInst = as<IRDifferentialPairGetPrimal>(inst))
{
- lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext);
-
- builder->setInsertBefore(getPrimalInst);
+ if (lowerPairType(builder, getPrimalInst->getBase()->getDataType(), diffContext))
+ {
+ builder->setInsertBefore(getPrimalInst);
- auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
- getPrimalInst->replaceUsesWith(primalFieldExtract);
- getPrimalInst->removeAndDeallocate();
+ auto primalFieldExtract = (&pairBuilderStorage)->emitPrimalFieldAccess(builder, getPrimalInst->getBase());
+ getPrimalInst->replaceUsesWith(primalFieldExtract);
+ getPrimalInst->removeAndDeallocate();
- return primalFieldExtract;
+ return primalFieldExtract;
+ }
}
return nullptr;
@@ -2295,9 +2326,9 @@ bool processJVPDerivativeMarkers(
return changed;
}
-void stripAutoDiffDecorations(IRModule* module)
+void stripAutoDiffDecorationsFromChildren(IRInst* parent)
{
- for (auto inst : module->getGlobalInsts())
+ for (auto inst : parent->getChildren())
{
for (auto decor = inst->getFirstDecoration(); decor; )
{
@@ -2313,7 +2344,18 @@ void stripAutoDiffDecorations(IRModule* module)
}
decor = next;
}
+
+ if (inst->getFirstChild() != nullptr)
+ {
+ stripAutoDiffDecorationsFromChildren(inst);
+ }
}
}
+void stripAutoDiffDecorations(IRModule* module)
+{
+ stripAutoDiffDecorationsFromChildren(module->getModuleInst());
+}
+
+
}
diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp
index 3ef79df28..53ea99a0c 100644
--- a/source/slang/slang-ir-specialize.cpp
+++ b/source/slang/slang-ir-specialize.cpp
@@ -373,7 +373,76 @@ struct SpecializationContext
// represents a definition rather than a declaration.
//
if(!canSpecializeGeneric(genericVal))
+ {
+ // We have to consider a special case here if baseVal is
+ // an intrinsic, and contains a custom differential.
+ // This is a case where the base cannot be specialized since it has
+ // no body, but the custom should be specialized.
+ // A better way to handle this would be to grab a reference to the
+ // appropriate custom differential, if one exists, at checking time
+ // during CheckInvoke() and construct it's specialization args appropriately.
+ //
+ // For now, we will overwrite the specialization args for the differential
+ // using the args for the base.
+ //
+ auto genericReturnVal = findInnerMostGenericReturnVal(genericVal);
+ if (genericReturnVal->findDecoration<IRTargetIntrinsicDecoration>())
+ {
+ if (auto customDiffRef = genericReturnVal->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ // If we already have a diff func on this specialize, skip.
+ if (auto specDiffRef = specInst->findDecoration<IRJVPDerivativeReferenceDecoration>())
+ {
+ return false;
+ }
+
+ auto specDiffFunc = as<IRSpecialize>(customDiffRef->getJVPFunc());
+
+ // If the base is specialized, the JVP version must be also be a specialized
+ // generic.
+ //
+ SLANG_ASSERT(specDiffFunc);
+
+ // Build specialization arguments from specInst.
+ // Note that if we've reached this point, we can safely assume
+ // that our args are fully specialized/concrete.
+ //
+ UCount argCount = specInst->getArgCount();
+ List<IRInst*> args;
+ for (UIndex ii = 0; ii < argCount; ii++)
+ args.add(specInst->getArg(ii));
+
+ IRBuilder builder(&sharedBuilderStorage);
+
+ // Specialize the custom JVP function type with the original arguments.
+ builder.setInsertInto(module);
+ auto newDiffFuncType = builder.emitSpecializeInst(
+ builder.getTypeKind(),
+ specDiffFunc->getBase()->getDataType(),
+ argCount,
+ args.getBuffer());
+
+ // Specialize the custom JVP function with the original arguments.
+ builder.setInsertBefore(specInst);
+ auto newDiffFunc = builder.emitSpecializeInst(
+ (IRType*) newDiffFuncType,
+ specDiffFunc->getBase(),
+ argCount,
+ args.getBuffer());
+
+ // Add the new spec insts to the list so they get specialized with
+ // the usual logic.
+ //
+ addToWorkList(newDiffFuncType);
+ addToWorkList(newDiffFunc);
+
+ builder.addJVPDerivativeReferenceDecoration(specInst, newDiffFunc);
+
+ return true;
+ }
+ }
return false;
+ }
// Once we know that specialization is possible,
// the actual work is fairly simple.
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 1e58a456e..ff66caa90 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -8204,12 +8204,28 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_ForceInlineDecoration);
}
+ // Register the value now, to avoid any possible infinite recursion when lowering CustomJVPAttribute
+ setGlobalValue(context, decl, LoweredValInfo::simple(findOuterMostGeneric(irFunc)));
+
if (auto attr = decl->findModifier<CustomJVPAttribute>())
{
- auto loweredVal = lowerLValueExpr(this->context, attr->funcDeclRef);
+ // TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly.
+ // If we don't move the cursor to the parent, we sometimes emit supporting
+ // insts into the function body, which shouldn't happen.
+ //
+ subContext->irBuilder->setInsertInto(irFunc->getParent());
+
+ auto diffFuncType = getFuncType(subContext->astBuilder, attr->funcDeclRef->declRef.as<CallableDecl>());
+ auto irDiffFuncType = lowerType(subContext, diffFuncType);
+
+ auto loweredVal = emitDeclRef(subContext, attr->funcDeclRef->declRef, irDiffFuncType);
+
SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
- IRFunc* jvpFunc = as<IRFunc>(loweredVal.val);
+ IRInst* jvpFunc = loweredVal.val;
getBuilder()->addDecoration(irFunc, kIROp_JVPDerivativeReferenceDecoration, jvpFunc);
+
+ // Reset cursor.
+ subContext->irBuilder->setInsertInto(irFunc);
}
// For convenience, ensure that any additional global