summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-09 19:19:17 -0800
committerGitHub <noreply@github.com>2022-11-09 19:19:17 -0800
commit004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch)
treecbc942746bab043da0eb5298993d95f9665dfddf
parentcedd93690c63188cf98e452c9d104cf51aad6c4e (diff)
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute. * Fix handling around phi nodes. * Fixes. * Remove IR opcode for ForwardDerivativeOfDecoration. Co-authored-by: Yong He <yhe@nvidia.com>
-rw-r--r--source/slang/core.meta.slang3
-rw-r--r--source/slang/diff.meta.slang126
-rw-r--r--source/slang/slang-ast-modifier.h13
-rw-r--r--source/slang/slang-ast-support-types.cpp20
-rw-r--r--source/slang/slang-ast-support-types.h4
-rw-r--r--source/slang/slang-ast-type.cpp5
-rw-r--r--source/slang/slang-ast-type.h5
-rw-r--r--source/slang/slang-check-decl.cpp95
-rw-r--r--source/slang/slang-check-expr.cpp9
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-check-modifier.cpp100
-rw-r--r--source/slang/slang-diagnostic-defs.h7
-rw-r--r--source/slang/slang-ir-diff-jvp.cpp112
-rw-r--r--source/slang/slang-ir-inst-defs.h2
-rw-r--r--source/slang/slang-ir-insts.h4
-rw-r--r--source/slang/slang-ir-link.cpp114
-rw-r--r--source/slang/slang-lower-to-ir.cpp246
-rw-r--r--source/slang/slang-mangle.cpp12
-rw-r--r--source/slang/slang-serialize.h10
-rw-r--r--tests/autodiff/custom-intrinsic-2.slang37
-rw-r--r--tests/autodiff/custom-intrinsic-2.slang.expected.txt6
-rw-r--r--tests/autodiff/dstdlib-vector.slang2
-rw-r--r--tests/autodiff/dstdlib.slang6
23 files changed, 567 insertions, 373 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index a37124bdc..e1eb9c776 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -2859,9 +2859,6 @@ __attributeTarget(InterfaceDecl)
attribute_syntax [Specialize] : SpecializeAttribute;
__attributeTarget(DeclBase)
-attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
-
-__attributeTarget(DeclBase)
attribute_syntax [builtin] : BuiltinAttribute;
__attributeTarget(DeclBase)
diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang
index c95f8e1ac..1f6064983 100644
--- a/source/slang/diff.meta.slang
+++ b/source/slang/diff.meta.slang
@@ -10,6 +10,13 @@ __attributeTarget(FunctionDeclBase)
attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute;
+__attributeTarget(FunctionDeclBase)
+attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute;
+
+__attributeTarget(DeclBase)
+attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute;
+
+
/// Pair type that serves to wrap the primal and
/// differential types of an arbitrary type T.
@@ -83,85 +90,46 @@ struct DifferentialPair : IDifferentiable
#define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \
vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result
-namespace dstd
+// Natural Exponent
+
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(exp)]
+DifferentialPair<T> __d_exp(DifferentialPair<T> dpx)
{
- // Natural Exponent
- __generic<T : __BuiltinFloatingPointType>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(cuda, "$P_exp($0)")
- __target_intrinsic(cpp, "$P_exp($0)")
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [ForwardDerivative(d_exp<T>)]
- T exp(T x);
-
- __generic<T : __BuiltinFloatingPointType>
- DifferentialPair<T> d_exp(DifferentialPair<T> dpx)
- {
- return DifferentialPair<T>(
- dstd.exp(dpx.p),
- T.dmul(dstd.exp(dpx.p), dpx.d));
- }
-
- // Sine
- __generic<T : __BuiltinFloatingPointType>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(cuda, "$P_sin($0)")
- __target_intrinsic(cpp, "$P_sin($0)")
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0")
- [ForwardDerivative(d_sin<T>)]
- T sin(T x);
-
- __generic<T : __BuiltinFloatingPointType>
- DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
- {
- return DifferentialPair<T>(
- dstd.sin(dpx.p),
- T.dmul(dstd.cos(dpx.p), dpx.d));
- }
-
- // Cosine
- __generic<T : __BuiltinFloatingPointType>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(cuda, "$P_cos($0)")
- __target_intrinsic(cpp, "$P_cos($0)")
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0")
- [ForwardDerivative(d_cos<T>)]
- T cos(T x);
-
- __generic<T : __BuiltinFloatingPointType>
- DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
- {
- return DifferentialPair<T>(
- dstd.cos(dpx.p),
- T.dmul(-dstd.sin(dpx.p), dpx.d));
- }
-
- __generic<let N : int>
- __target_intrinsic(hlsl)
- __target_intrinsic(glsl)
- __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0")
- [ForwardDerivative(d_exp_vector)]
- vector<float, N> exp(vector<float, N> x)
- {
- VECTOR_MAP_UNARY(float, N, dstd.exp, x);
- }
-
- __generic<let N : int>
- DifferentialPair<vector<float, N>> d_exp_vector(DifferentialPair<vector<float, N>> dpx)
- {
- vector<float, N> result;
- vector<float, N>.Differential d_result;
- for(int i = 0; i < N; ++i)
- {
- DifferentialPair<float> dpexp = dstd.d_exp(DifferentialPair<float>(dpx.p[i], dpx.d[i]));
- result[i] = dpexp.p;
- d_result[i] = dpexp.d;
- }
-
- return DifferentialPair<vector<float, N>>(result, d_result);
+ return DifferentialPair<T>(
+ exp(dpx.p),
+ T.dmul(exp(dpx.p), dpx.d));
+}
+
+__generic<T:__BuiltinFloatingPointType, let N : int>
+[ForwardDerivativeOf(exp)]
+DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx)
+{
+ vector<T, N> result;
+ vector<T, N>.Differential d_result;
+ for(int i = 0; i < N; ++i)
+ {
+ DifferentialPair<T> dpexp = __d_exp(DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i])));
+ result[i] = dpexp.p;
+ d_result[i] = __slang_noop_cast<T>(dpexp.d);
}
+ return DifferentialPair<vector<T, N>>(result, d_result);
+}
-};
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(sin)]
+DifferentialPair<T> d_sin(DifferentialPair<T> dpx)
+{
+ return DifferentialPair<T>(
+ sin(dpx.p),
+ T.dmul(cos(dpx.p), dpx.d));
+}
+
+__generic<T : __BuiltinFloatingPointType>
+[ForwardDerivativeOf(cos)]
+DifferentialPair<T> d_cos(DifferentialPair<T> dpx)
+{
+ return DifferentialPair<T>(
+ cos(dpx.p),
+ T.dmul(-sin(dpx.p), dpx.d));
+}
diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h
index 57dfbac9e..d6a961328 100644
--- a/source/slang/slang-ast-modifier.h
+++ b/source/slang/slang-ast-modifier.h
@@ -1031,7 +1031,18 @@ class ForwardDerivativeAttribute : public DifferentiableAttribute
{
SLANG_AST_CLASS(ForwardDerivativeAttribute)
- DeclRefExpr* funcDeclRef;
+ Expr* funcExpr;
+};
+
+ /// The `[ForwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom
+ /// derivative implementation for `primalFunction`.
+class ForwardDerivativeOfAttribute : public Attribute
+{
+ SLANG_AST_CLASS(ForwardDerivativeOfAttribute)
+
+ Expr* funcExpr;
+
+ Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction.
};
/// Indicates that the modified declaration is one of the "magic" declarations
diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp
index 7133f2a65..1f30e0238 100644
--- a/source/slang/slang-ast-support-types.cpp
+++ b/source/slang/slang-ast-support-types.cpp
@@ -14,4 +14,24 @@ QualType::QualType(Type* type)
}
}
+void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove)
+{
+ Modifier* prev = nullptr;
+ for (auto modifier = syntax->modifiers.first; modifier; modifier = modifier->next)
+ {
+ if (modifier == toRemove)
+ {
+ if (prev)
+ {
+ prev->next = modifier->next;
+ }
+ else
+ {
+ syntax->modifiers.first = syntax->modifiers.first->next;
+ }
+ break;
+ }
+ prev = modifier;
+ }
+}
}
diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h
index d4a781846..89ae0da7d 100644
--- a/source/slang/slang-ast-support-types.h
+++ b/source/slang/slang-ast-support-types.h
@@ -474,6 +474,10 @@ namespace Slang
ModifiableSyntaxNode* syntax,
Modifier* modifier);
+ void removeModifier(
+ ModifiableSyntaxNode* syntax,
+ Modifier* modifier);
+
struct QualType
{
SLANG_VALUE_CLASS(QualType)
diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp
index a869c95a7..ba033c3ad 100644
--- a/source/slang/slang-ast-type.cpp
+++ b/source/slang/slang-ast-type.cpp
@@ -473,6 +473,11 @@ Type* NamespaceType::_createCanonicalTypeOverride()
return this;
}
+Type* DifferentialPairType::getPrimalType()
+{
+ return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]);
+}
+
// !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! PtrTypeBase !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h
index c7ce21cb0..d9829c4ca 100644
--- a/source/slang/slang-ast-type.h
+++ b/source/slang/slang-ast-type.h
@@ -463,10 +463,7 @@ protected:
class DifferentialPairType : public ArithmeticExpressionType
{
SLANG_AST_CLASS(DifferentialPairType)
-
- // The type of vector elements.
- // As an invariant, this should be a basic type or an alias.
- Type* baseType = nullptr;
+ Type* getPrimalType();
};
class DifferentiableType : public BuiltinType
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 333e9d973..b33c33e7a 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -256,6 +256,11 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context);
+
+ void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl);
+
+ void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr);
+
};
/// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration?
@@ -4582,11 +4587,101 @@ namespace Slang
}
}
+ void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl)
+ {
+ auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();
+ if (!attr)
+ return;
+
+ List<Expr*> imaginaryArguments;
+ for (auto param : funcDecl->getParameters())
+ {
+ auto arg = m_astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = attr->loc;
+ if (auto pairType = as<DifferentialPairType>(param->getType()))
+ {
+ arg->type.type = pairType->getPrimalType();
+ }
+ imaginaryArguments.add(arg);
+ }
+ auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto resolved = ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier<ForwardDerivativeAttribute>())
+ {
+ // The primal function already has a `[ForwardDerivative]` attribute, this is invalid.
+ getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]");
+ getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl());
+ }
+ attr->funcExpr = calleeDeclRef;
+ auto fwdDerivativeAttr = m_astBuilder->create<ForwardDerivativeAttribute>();
+ fwdDerivativeAttr->loc = attr->loc;
+ auto outterGeneric = GetOuterGeneric(funcDecl);
+ auto declRef =
+ DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr);
+ auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr);
+ declRefExpr->type.type = nullptr;
+ fwdDerivativeAttr->args.add(declRefExpr);
+ fwdDerivativeAttr->funcExpr = declRefExpr;
+ checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr);
+ attr->backDeclRef = fwdDerivativeAttr->funcExpr;
+ fwdDerivativeAttr->funcExpr = nullptr;
+ return;
+ }
+ }
+ getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
+ void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr)
+ {
+ if (!attr->funcExpr)
+ return;
+ if (attr->funcExpr->type.type)
+ return;
+
+ List<Expr*> imaginaryArguments;
+ for (auto param : funcDecl->getParameters())
+ {
+ auto arg = m_astBuilder->create<VarExpr>();
+ arg->declRef.decl = param;
+ arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false;
+ arg->type.type = param->getType();
+ arg->loc = attr->loc;
+ if (auto pairType = getDifferentialPairType(param->getType()))
+ {
+ arg->type.type = pairType;
+ }
+ imaginaryArguments.add(arg);
+ }
+ auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments);
+ auto resolved = ResolveInvoke(invokeExpr);
+ if (auto resolvedInvoke = as<InvokeExpr>(resolved))
+ {
+ if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr))
+ {
+ attr->funcExpr = calleeDeclRef;
+ return;
+ }
+ }
+ getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative);
+ }
+
void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl)
{
auto newContext = withParentFunc(decl);
_maybeRegisterDifferentialBottomTypeConformance(newContext);
+ // Run checking on attributes that can't be fully checked in header checking stage.
+ checkDerivativeOfAttribute(decl);
+ if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>())
+ checkDerivativeAttribute(decl, derivativeAttr);
+
if (auto body = decl->body)
{
checkStmt(decl->body, newContext);
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 09dd9eea1..30db9ecfa 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -393,6 +393,15 @@ namespace Slang
return derefExpr;
}
+ InvokeExpr* SemanticsVisitor::constructUncheckedInvokeExpr(Expr* callee, const List<Expr*>& arguments)
+ {
+ auto result = m_astBuilder->create<InvokeExpr>();
+ result->loc = callee->loc;
+ result->functionExpr = callee;
+ result->arguments.addRange(arguments);
+ return result;
+ }
+
Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr)
{
// If the only result from lookup is an entry in an interface decl, it could be that
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 76918ebbe..70b120518 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -542,6 +542,8 @@ namespace Slang
Expr* base,
SourceLoc loc);
+ InvokeExpr* constructUncheckedInvokeExpr(Expr* callee, const List<Expr*>& arguments);
+
Expr* maybeUseSynthesizedDeclForLookupResult(
LookupResultItem const& item,
Expr* orignalExpr);
diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp
index d8b05198c..b8ac21e2d 100644
--- a/source/slang/slang-check-modifier.cpp
+++ b/source/slang/slang-check-modifier.cpp
@@ -617,92 +617,30 @@ namespace Slang
getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
return false;
}
-
- // Either diffExpr has a function type, or it is a reference to a generic.
- if (!as<FuncType>(diffExpr->type) &&
- !(as<DeclRefExpr>(diffExpr) &&
- as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr))
- {
- return false;
- }
-
- auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef;
-
- UCount genericLevels = 0;
- // If we've grabbed the outer generic for some reason,
- // recursively construct GenericAppExpr<...>(generic)
- // and check that to get a specialized func.
- //
- while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr)
- {
- // Forward to the inner decl
- diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner);
-
- // Increment counter.
- genericLevels += 1;
- }
-
- auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl);
- auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl);
- Expr* currentDiffExpr = diffExpr;
-
- // Go back through each level, and use generic declarations in the
- // target's generic scope as arguments for the diff function's generic.
+ // We store the partially checked funcExpr in the attribute, and
+ // rely on `ResolveInvoke` to resolve it to the actual function decl.
+ // The call to `ResolveInvoke` is deferred until we are checking the
+ // body of the function.
//
- for (UIndex ii = 0; ii < genericLevels; ii++)
- {
- // Nest our expression inside a GenericAppExpr
- auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>();
- genericAppExpr->functionExpr = currentDiffExpr;
-
- // Construct references to the generic args in the current scope.
- // TODO: Probably an easier way to do this.
- for (auto member : targetGeneric->members)
- {
- if (auto typeParamDecl = as<GenericTypeParamDecl>(member))
- {
- genericAppExpr->arguments.add(
- ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr));
- }
- else if (auto valueParamDecl = as<GenericValueParamDecl>(member))
- {
- genericAppExpr->arguments.add(
- ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr));
- }
- }
-
- // Set our generic-app-expr as the new expr.
- currentDiffExpr = genericAppExpr;
-
- // Peel the generic layer.
- diffGeneric = as<GenericDecl>(diffGeneric->parentDecl);
- targetGeneric = as<GenericDecl>(targetGeneric->parentDecl);
- }
-
- if ((diffGeneric == nullptr && targetGeneric != nullptr) ||
- (targetGeneric == nullptr && diffGeneric != nullptr))
- {
- //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget);
- SLANG_UNEXPECTED("");
- }
-
- // If we had to change currentDiffExpr, then re-check the expr.
- if (!currentDiffExpr->type)
- {
- currentDiffExpr = CheckTerm(currentDiffExpr);
- }
+ // Set type to null to indicate that this needs expr needs to be further resolved.
+ diffExpr->type.type = nullptr;
+ forwardDerivativeAttr->funcExpr = diffExpr;
+ }
+ else if (auto forwardDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr))
+ {
+ SLANG_ASSERT(attr->args.getCount() == 1);
+ SLANG_ASSERT(as<Decl>(attrTarget));
// Ensure that the argument is a reference to a function definition or declaration.
- auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr);
- auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef;
-
- if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc())))
+ auto primalFunc = CheckTerm(attr->args[0]);
+ if (primalFunc->type == getASTBuilder()->getErrorType())
{
- getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef);
+ // Could not resolve the term.
+ getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget));
+ return false;
}
-
- // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr)
- forwardDerivativeAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr));
+
+ forwardDerivativeOfAttr->funcExpr = primalFunc;
}
else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr))
{
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index ffee0622c..5263ac39b 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -341,10 +341,9 @@ DIAGNOSTIC(31141, Error, definitionOfExternDeclMismatchesOriginalDefinition, "`e
DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl '$0' has ambiguous original definitions.")
DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.")
-DIAGNOSTIC(31144, Error, customDerivativeNotAFunction, "$0, used as a custom derivative, is not a function")
-DIAGNOSTIC(31145, Error, customDerivativeGenericSignatureMismatch, "cannot use $0 as custom derivative for $1. generic signature does not match")
-DIAGNOSTIC(31146, Error, customDerivativeSignatureMismatch, "cannot use $0 as custom derivative for $1. signature does not match")
-DIAGNOSTIC(31146, Error, invalidCustomDerivative, "unable to resolve custom differential for $0.")
+DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.")
+DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '$1'.")
+
// Enums
DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'")
diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp
index 574db2036..4c7a132d0 100644
--- a/source/slang/slang-ir-diff-jvp.cpp
+++ b/source/slang/slang-ir-diff-jvp.cpp
@@ -847,37 +847,51 @@ struct JVPTranscriber
cloneInst(&cloneEnv, builder, origParam),
nullptr);
}
-
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+
+ // Is this param a phi node or a function parameter?
+ auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent());
+ bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock());
+ if (isFuncParam)
{
- IRInst* diffPairParam = builder->emitParam(diffPairType);
+ if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType))
+ {
+ IRInst* diffPairParam = builder->emitParam(diffPairType);
- auto diffPairVarName = makeDiffPairName(origParam);
- if (diffPairVarName.getLength() > 0)
- builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
+ auto diffPairVarName = makeDiffPairName(origParam);
+ if (diffPairVarName.getLength() > 0)
+ builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice());
- SLANG_ASSERT(diffPairParam);
+ SLANG_ASSERT(diffPairParam);
- if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType()))
+ {
+ return InstPair(
+ builder->emitDifferentialPairGetPrimal(diffPairParam),
+ builder->emitDifferentialPairGetDifferential(
+ (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
+ diffPairParam));
+ }
+ // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
+ // its primal and diff parts right now because they would represent a reference
+ // to a pair field, which doesn't make sense since pair types are considered mutable.
+ // We encode the result as if the param is non-differentiable, and handle it
+ // with special care at load/store.
+ return InstPair(diffPairParam, nullptr);
+ }
+ return InstPair(
+ cloneInst(&cloneEnv, builder, origParam),
+ nullptr);
+ }
+ else
+ {
+ auto primal = cloneInst(&cloneEnv, builder, origParam);
+ IRInst* diff = nullptr;
+ if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType))
{
- return InstPair(
- builder->emitDifferentialPairGetPrimal(diffPairParam),
- builder->emitDifferentialPairGetDifferential(
- (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType),
- diffPairParam));
+ diff = builder->emitParam(diffType);
}
- // If this is an `in/inout DifferentialPair<>` parameter, we can't produce
- // its primal and diff parts right now because they would represent a reference
- // to a pair field, which doesn't make sense since pair types are considered mutable.
- // We encode the result as if the param is non-differentiable, and handle it
- // with special care at load/store.
- return InstPair(diffPairParam, nullptr);
+ return InstPair(primal, diff);
}
-
-
- return InstPair(
- cloneInst(&cloneEnv, builder, origParam),
- nullptr);
}
// Returns "d<var-name>" to use as a name hint for variables and parameters.
@@ -1313,42 +1327,49 @@ struct JVPTranscriber
switch(origInst->getOp())
{
case kIROp_unconditionalBranch:
+ case kIROp_loop:
auto origBranch = as<IRUnconditionalBranch>(origInst);
// Grab the differentials for any phi nodes.
- List<IRInst*> pairArgs;
+ List<IRInst*> newArgs;
for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++)
{
auto origArg = origBranch->getArg(ii);
+ auto primalArg = lookupPrimalInst(origArg);
+ newArgs.add(primalArg);
- IRInst* pairArg = nullptr;
- if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType()))
+ if (differentiateType(builder, primalArg->getDataType()))
{
auto diffArg = lookupDiffInst(origArg, nullptr);
- if (!diffArg)
- {
- diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType());
- }
-
- pairArg = builder->emitMakeDifferentialPair(
- diffPairType,
- lookupPrimalInst(origArg),
- diffArg);
- }
- else
- {
- pairArg = lookupPrimalInst(origArg);
+ if (diffArg)
+ newArgs.add(diffArg);
}
- pairArgs.add(pairArg);
}
IRInst* diffBranch = nullptr;
if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock()))
{
- diffBranch = builder->emitBranch(
- as<IRBlock>(diffBlock),
- pairArgs.getCount(),
- pairArgs.getBuffer());
+ if (auto origLoop = as<IRLoop>(origInst))
+ {
+ auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock());
+ auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock());
+ List<IRInst*> operands;
+ operands.add(breakBlock);
+ operands.add(continueBlock);
+ operands.addRange(newArgs);
+ diffBranch = builder->emitIntrinsicInst(
+ nullptr,
+ kIROp_loop,
+ operands.getCount(),
+ operands.getBuffer());
+ }
+ else
+ {
+ diffBranch = builder->emitBranch(
+ as<IRBlock>(diffBlock),
+ newArgs.getCount(),
+ newArgs.getBuffer());
+ }
}
// For now, every block in the original fn must have a corresponding
@@ -2517,5 +2538,4 @@ void stripAutoDiffDecorations(IRModule* module)
stripAutoDiffDecorationsFromChildren(module->getModuleInst());
}
-
}
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index 431446f01..cb4854d7d 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -705,7 +705,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0)
/// Used by the auto-diff pass to hold a reference to the
/// generated derivative function.
- INST(ForwardDerivativeDecoration, jvpFnReference, 1, 0)
+ INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0)
/// Used by the auto-diff pass to hold a reference to a
/// differential member of a type in its associated differential type.
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 1d1e2ae69..5587a7c68 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -3207,9 +3207,9 @@ public:
addDecoration(value, kIROp_ForwardDifferentiableDecoration);
}
- void addForwardDerivativeDecoration(IRInst* value, IRInst* jvpFn)
+ void addForwardDerivativeDecoration(IRInst* value, IRInst* fwdFunc)
{
- addDecoration(value, kIROp_ForwardDerivativeDecoration, jvpFn);
+ addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc);
}
void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable)
diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp
index ad4f691f1..cf0293f0d 100644
--- a/source/slang/slang-ir-link.cpp
+++ b/source/slang/slang-ir-link.cpp
@@ -6,7 +6,7 @@
#include "slang-ir-insts.h"
#include "slang-mangle.h"
#include "slang-ir-string-hash.h"
-
+#include "slang-ir-diff-jvp.h"
#include "slang-module-library.h"
#include "../compiler-core/slang-artifact.h"
@@ -412,7 +412,43 @@ IRGlobalVar* cloneGlobalVarImpl(
/// For a given decoration opcode, only one such decoration will ever be copied, and nothing
/// will be copied if the instruction already has a matching decoration (that was cloned
/// from the "best" definition).
- ///
+ ///
+static void cloneExtraDecorationsFromInst(
+ IRSpecContextBase* context,
+ IRBuilder* builder,
+ IRInst* clonedInst,
+ IRInst* originalInst)
+{
+ for (auto decoration : originalInst->getDecorations())
+ {
+ switch (decoration->getOp())
+ {
+ default:
+ break;
+
+ case kIROp_HLSLExportDecoration:
+ case kIROp_BindExistentialSlotsDecoration:
+ case kIROp_LayoutDecoration:
+ case kIROp_PublicDecoration:
+ case kIROp_SequentialIDDecoration:
+ case kIROp_ForwardDerivativeDecoration:
+ if (!clonedInst->findDecorationImpl(decoration->getOp()))
+ {
+ cloneInst(context, builder, decoration);
+ }
+ break;
+ }
+ }
+
+ // We will also copy over source location information from the alternative
+ // values, in case any of them has it available.
+ //
+ if (originalInst->sourceLoc.isValid() && !clonedInst->sourceLoc.isValid())
+ {
+ clonedInst->sourceLoc = originalInst->sourceLoc;
+ }
+}
+
static void cloneExtraDecorations(
IRSpecContextBase* context,
IRInst* clonedInst,
@@ -435,34 +471,7 @@ static void cloneExtraDecorations(
for(auto sym = originalValues.sym; sym; sym = sym->nextWithSameName)
{
- for(auto decoration : sym->irGlobalValue->getDecorations())
- {
- switch(decoration->getOp())
- {
- default:
- break;
-
- case kIROp_HLSLExportDecoration:
- case kIROp_BindExistentialSlotsDecoration:
- case kIROp_LayoutDecoration:
- case kIROp_PublicDecoration:
- case kIROp_SequentialIDDecoration:
- case kIROp_ForwardDerivativeDecoration:
- if(!clonedInst->findDecorationImpl(decoration->getOp()))
- {
- cloneInst(context, builder, decoration);
- }
- break;
- }
- }
-
- // We will also copy over source location information from the alternative
- // values, in case any of them has it available.
- //
- if(sym->irGlobalValue->sourceLoc.isValid() && !clonedInst->sourceLoc.isValid())
- {
- clonedInst->sourceLoc = sym->irGlobalValue->sourceLoc;
- }
+ cloneExtraDecorationsFromInst(context, builder, clonedInst, sym->irGlobalValue);
}
}
@@ -547,6 +556,43 @@ IRGeneric* cloneGenericImpl(
originalVal,
originalValues);
+ // We want to clone extra decorations on the
+ // return value from other symbols as well.
+ auto clonedInnerVal = findGenericReturnVal(clonedVal);
+ for (auto originalSym = originalValues.sym; originalSym;
+ originalSym = originalSym->nextWithSameName.get())
+ {
+ auto originalGeneric = as<IRGeneric>(originalSym->irGlobalValue);
+ if (!originalGeneric)
+ continue;
+ auto originalInnerVal = findGenericReturnVal(originalGeneric);
+
+ // Register all generic parameters before cloning the decorations.
+ auto clonedParam = clonedVal->getFirstParam();
+ auto originalParam = originalGeneric->getFirstParam();
+
+ ShortList<KeyValuePair<IRInst*, IRInst*>> paramMapping;
+ for (; clonedParam && originalParam; (clonedParam = as<IRParam>(clonedParam->next)), (originalParam = as<IRParam>(originalParam->next)))
+ {
+ paramMapping.add(KeyValuePair<IRInst*, IRInst*>(clonedParam, originalParam));
+ }
+ // Generic parameter list does not match, bail.
+ if (clonedParam || originalParam)
+ continue;
+ for (auto kv : paramMapping)
+ {
+ registerClonedValue(context, kv.Key, kv.Value);
+ }
+
+ IRBuilder builderStorage = *builder;
+ IRBuilder* decorBuilder = &builderStorage;
+ decorBuilder->setInsertInto(clonedInnerVal);
+ if (auto firstChild = clonedInnerVal->getFirstChild())
+ {
+ decorBuilder->setInsertBefore(firstChild);
+ }
+ cloneExtraDecorationsFromInst(context, decorBuilder, clonedInnerVal, originalInnerVal);
+ }
return clonedVal;
}
@@ -694,7 +740,6 @@ void cloneGlobalValueWithCodeCommon(
cb = cb->getNextBlock();
}
}
-
}
void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const& mangledName)
@@ -1405,6 +1450,13 @@ LinkedIR linkIR(
//
List<IRModule*> irModules;
+
+ // Link stdlib modules.
+ auto builtinLinkage = static_cast<Session*>(linkage->getGlobalSession())->getBuiltinLinkage();
+ for (auto& m : builtinLinkage->mapNameToLoadedModules)
+ irModules.add(m.Value->getIRModule());
+
+ // Link modules in the program.
program->enumerateIRModules([&](IRModule* irModule)
{
irModules.add(irModule);
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 12a9f73e6..5930875f1 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -547,6 +547,7 @@ bool isImportedDecl(IRGenContext* context, Decl* decl)
if (!moduleDecl)
return false;
+#if 0
// HACK: don't treat standard library code as
// being imported for right now, just because
// we don't load its IR in the same way as
@@ -557,6 +558,7 @@ bool isImportedDecl(IRGenContext* context, Decl* decl)
// in via the normal means.
if (isFromStdLib(decl))
return false;
+#endif
if (moduleDecl != context->getMainModuleDecl())
return true;
@@ -7782,22 +7784,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return type->getOp() == kIROp_ClassType;
}
- LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
+ LoweredValInfo lowerFuncDeclInContext(IRGenContext* subContext, IRBuilder* subBuilder, FunctionDeclBase* decl)
{
- // We are going to use a nested builder, because we will
- // change the parent node that things get nested into.
- //
- NestedContext nestedContextFunc(this);
- auto subBuilder = nestedContextFunc.getBuilder();
- auto subContext = nestedContextFunc.getContext();
auto outerGeneric = emitOuterGenerics(subContext, decl, decl);
// need to create an IR function here
IRFunc* irFunc = subBuilder->createFunc();
- addNameHint(context, irFunc, decl);
- addLinkageDecoration(context, irFunc, decl);
+ addNameHint(subContext, irFunc, decl);
+ addLinkageDecoration(subContext, irFunc, decl);
if (decl->findModifier<ForwardDifferentiableAttribute>())
{
@@ -7868,7 +7864,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subBuilder->setInsertInto(entryBlock);
UInt paramTypeIndex = 0;
- for( auto paramInfo : parameterLists.params )
+ for (auto paramInfo : parameterLists.params)
{
auto irParamType = paramTypes[paramTypeIndex++];
@@ -7876,91 +7872,91 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IRParam* irParam = nullptr;
- switch( paramInfo.direction )
+ switch (paramInfo.direction)
{
default:
- {
- // The parameter is being used for input/output purposes,
- // so it will lower to an actual parameter with a pointer type.
- //
- // TODO: Is this the best representation we can use?
+ {
+ // The parameter is being used for input/output purposes,
+ // so it will lower to an actual parameter with a pointer type.
+ //
+ // TODO: Is this the best representation we can use?
- irParam = subBuilder->emitParam(irParamType);
- if(auto paramDecl = paramInfo.decl)
- {
- addVarDecorations(context, irParam, paramDecl);
- subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
- }
- addParamNameHint(irParam, paramInfo);
+ irParam = subBuilder->emitParam(irParamType);
+ if (auto paramDecl = paramInfo.decl)
+ {
+ addVarDecorations(context, irParam, paramDecl);
+ subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
+ }
+ addParamNameHint(irParam, paramInfo);
- paramVal = LoweredValInfo::ptr(irParam);
+ paramVal = LoweredValInfo::ptr(irParam);
- // TODO: We might want to copy the pointed-to value into
- // a temporary at the start of the function, and then copy
- // back out at the end, so that we don't have to worry
- // about things like aliasing in the function body.
- //
- // For now we will just use the storage that was passed
- // in by the caller, knowing that our current lowering
- // at call sites will guarantee a fresh/unique location.
- }
- break;
+ // TODO: We might want to copy the pointed-to value into
+ // a temporary at the start of the function, and then copy
+ // back out at the end, so that we don't have to worry
+ // about things like aliasing in the function body.
+ //
+ // For now we will just use the storage that was passed
+ // in by the caller, knowing that our current lowering
+ // at call sites will guarantee a fresh/unique location.
+ }
+ break;
case kParameterDirection_In:
+ {
+ // Simple case of a by-value input parameter.
+ //
+ // We start by declaring an IR parameter of the same type.
+ //
+ auto paramDecl = paramInfo.decl;
+ irParam = subBuilder->emitParam(irParamType);
+ if (paramDecl)
{
- // Simple case of a by-value input parameter.
- //
- // We start by declaring an IR parameter of the same type.
- //
- auto paramDecl = paramInfo.decl;
- irParam = subBuilder->emitParam(irParamType);
- if( paramDecl )
- {
- addVarDecorations(context, irParam, paramDecl);
- subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
- }
- addParamNameHint(irParam, paramInfo);
- paramVal = LoweredValInfo::simple(irParam);
- //
- // HLSL allows a function parameter to be used as a local
- // variable in the function body (just like C/C++), so
- // we need to support that case as well.
+ addVarDecorations(context, irParam, paramDecl);
+ subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
+ }
+ addParamNameHint(irParam, paramInfo);
+ paramVal = LoweredValInfo::simple(irParam);
+ //
+ // HLSL allows a function parameter to be used as a local
+ // variable in the function body (just like C/C++), so
+ // we need to support that case as well.
+ //
+ // However, if we notice that the parameter was marked
+ // `const`, then we can skip this step.
+ //
+ // TODO: we should consider having all parameter be implicitly
+ // immutable except in a specific "compatibility mode."
+ //
+ if (paramDecl && paramDecl->findModifier<ConstModifier>())
+ {
+ // This parameter was declared to be immutable,
+ // so there should be no assignment to it in the
+ // function body, and we don't need a temporary.
+ }
+ else
+ {
+ // The parameter migth get used as a temporary in
+ // the function body. We will allocate a mutable
+ // local variable for is value, and then assign
+ // from the parameter to the local at the start
+ // of the function.
//
- // However, if we notice that the parameter was marked
- // `const`, then we can skip this step.
+ auto irLocal = subBuilder->emitVar(irParamType);
+ auto localVal = LoweredValInfo::ptr(irLocal);
+ assign(subContext, localVal, paramVal);
//
- // TODO: we should consider having all parameter be implicitly
- // immutable except in a specific "compatibility mode."
+ // When code later in the body of the function refers
+ // to the parameter declaration, it will actually refer
+ // to the value stored in the local variable.
//
- if(paramDecl && paramDecl->findModifier<ConstModifier>())
- {
- // This parameter was declared to be immutable,
- // so there should be no assignment to it in the
- // function body, and we don't need a temporary.
- }
- else
- {
- // The parameter migth get used as a temporary in
- // the function body. We will allocate a mutable
- // local variable for is value, and then assign
- // from the parameter to the local at the start
- // of the function.
- //
- auto irLocal = subBuilder->emitVar(irParamType);
- auto localVal = LoweredValInfo::ptr(irLocal);
- assign(subContext, localVal, paramVal);
- //
- // When code later in the body of the function refers
- // to the parameter declaration, it will actually refer
- // to the value stored in the local variable.
- //
- paramVal = localVal;
- }
+ paramVal = localVal;
}
- break;
+ }
+ break;
}
- if( auto paramDecl = paramInfo.decl )
+ if (auto paramDecl = paramInfo.decl)
{
setValue(subContext, paramDecl, paramVal);
}
@@ -8008,7 +8004,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// a local variable to represent this value.
//
auto constructorDecl = as<ConstructorDecl>(decl);
- if(constructorDecl)
+ if (constructorDecl)
{
auto thisVar = subContext->irBuilder->emitVar(irResultType);
subContext->thisVal = LoweredValInfo::ptr(thisVar);
@@ -8031,7 +8027,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
if (!subContext->irBuilder->getBlock()->getTerminator())
{
- if(constructorDecl)
+ if (constructorDecl)
{
// A constructor declaration should return the
// value of the `this` variable that was set
@@ -8044,7 +8040,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
subContext->irBuilder->emitReturn(
getSimpleVal(subContext, subContext->thisVal));
}
- else if(as<IRVoidType>(irResultType))
+ else if (as<IRVoidType>(irResultType))
{
// `void`-returning function can get an implicit
// return on exit of the body statement.
@@ -8075,7 +8071,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// If this declaration was marked as being an intrinsic for a particular
// target, then we should reflect that here.
- for( auto targetMod : decl->getModifiersOfType<SpecializedForTargetModifier>() )
+ for (auto targetMod : decl->getModifiersOfType<SpecializedForTargetModifier>())
{
// `targetMod` indicates that this particular declaration represents
// a specialized definition of the particular function for the given
@@ -8099,11 +8095,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// TODO: We should wrap this an `SpecializedForTargetModifier` together into a single
// case for enumerating the "capabilities" that a declaration requires.
//
- for(auto extensionMod : decl->getModifiersOfType<RequiredGLSLExtensionModifier>())
+ for (auto extensionMod : decl->getModifiersOfType<RequiredGLSLExtensionModifier>())
{
getBuilder()->addRequireGLSLExtensionDecoration(irFunc, extensionMod->extensionNameToken.getContent());
}
- for(auto versionMod : decl->getModifiersOfType<RequiredGLSLVersionModifier>())
+ for (auto versionMod : decl->getModifiersOfType<RequiredGLSLVersionModifier>())
{
getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken)));
}
@@ -8116,12 +8112,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version);
}
- if(decl->findModifier<RequiresNVAPIAttribute>())
+ if (decl->findModifier<RequiresNVAPIAttribute>())
{
getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc);
}
- if(decl->findModifier<NoInlineAttribute>())
+ if (decl->findModifier<NoInlineAttribute>())
{
getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc);
}
@@ -8132,13 +8128,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit);
}
- if(auto attr = decl->findModifier<MaxVertexCountAttribute>())
+ if (auto attr = decl->findModifier<MaxVertexCountAttribute>())
{
IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr);
getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit);
}
- if(auto attr = decl->findModifier<NumThreadsAttribute>())
+ if (auto attr = decl->findModifier<NumThreadsAttribute>())
{
auto builder = getBuilder();
IRType* intType = builder->getIntType();
@@ -8149,10 +8145,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->getIntValue(intType, attr->z)
};
- builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
+ builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3);
}
- if(decl->findModifier<ReadNoneAttribute>())
+ if (decl->findModifier<ReadNoneAttribute>())
{
getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc);
}
@@ -8192,7 +8188,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit);
}
- if(decl->findModifier<UnsafeForceInlineEarlyAttribute>())
+ if (decl->findModifier<UnsafeForceInlineEarlyAttribute>())
{
getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration);
}
@@ -8207,23 +8203,54 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (auto attr = decl->findModifier<ForwardDerivativeAttribute>())
{
- // TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly.
- // If we don't move the cursor to the parent, we sometimes emit supporting
- // insts into the function body, which shouldn't happen.
- //
- subContext->irBuilder->setInsertInto(irFunc->getParent());
-
- auto diffFuncType = getFuncType(subContext->astBuilder, attr->funcDeclRef->declRef.as<CallableDecl>());
- auto irDiffFuncType = lowerType(subContext, diffFuncType);
+ // We need to lower the decl ref to the custom derivative function to IR.
+ // The IR insts correspond to the decl ref is not part of the function we
+ // are processing. If we emit it directly to within the function, it could
+ // mess up the assumption on the form of the IR (e.g. having non decoration insts
+ // appearing in the middle of decoration insts). so we emit the decl ref to the
+ // function's parent for now.
+
+ subContext->irBuilder->setInsertInto(irFunc->getParent());
- auto loweredVal = emitDeclRef(subContext, attr->funcDeclRef->declRef, irDiffFuncType);
+ auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr);
SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
IRInst* jvpFunc = loweredVal.val;
getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc);
// Reset cursor.
- subContext->irBuilder->setInsertInto(irFunc);
+ subContext->irBuilder->setInsertInto(irFunc);
+ }
+
+ if (auto attr = decl->findModifier<ForwardDerivativeOfAttribute>())
+ {
+ if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr))
+ {
+ NestedContext originalContextFunc(this);
+ auto originalSubBuilder = originalContextFunc.getBuilder();
+ auto originalSubContext = originalContextFunc.getContext();
+
+ auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl());
+ SLANG_RELEASE_ASSERT(originalFuncDecl);
+
+ auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val;
+ if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal))
+ {
+ originalFuncVal = findGenericReturnVal(originalFuncGeneric);
+ }
+ originalSubBuilder->setInsertBefore(originalFuncVal);
+ auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef);
+ originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val);
+ }
+
+ subContext->irBuilder->setInsertInto(irFunc->getParent());
+ auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr);
+
+ SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple);
+ IRInst* originalFunc = loweredVal.val;
+ getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, originalFunc);
+
+ subContext->irBuilder->setInsertInto(irFunc);
}
// For convenience, ensure that any additional global
@@ -8239,6 +8266,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo::simple(finalVal);
}
+ LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl)
+ {
+ // We are going to use a nested builder, because we will
+ // change the parent node that things get nested into.
+ //
+ NestedContext nestedContextFunc(this);
+ auto subBuilder = nestedContextFunc.getBuilder();
+ auto subContext = nestedContextFunc.getContext();
+ return lowerFuncDeclInContext(subContext, subBuilder, decl);
+ }
+
LoweredValInfo visitGenericDecl(GenericDecl * genDecl)
{
// TODO: Should this just always visit/lower the inner decl?
diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp
index ab1c1ec4a..f88549e41 100644
--- a/source/slang/slang-mangle.cpp
+++ b/source/slang/slang-mangle.cpp
@@ -301,17 +301,7 @@ namespace Slang
auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>();
if( parentDeclRef )
{
- // In certain cases we want to skip emitting the parent
- if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner != declRef.getDecl()))
- {
- }
- else if(parentDeclRef.as<FunctionDeclBase>())
- {
- }
- else
- {
- emitQualifiedName(context, parentDeclRef);
- }
+ emitQualifiedName(context, parentDeclRef);
}
// A generic declaration is kind of a pseudo-declaration
diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h
index e08d26dd5..581ce2e5f 100644
--- a/source/slang/slang-serialize.h
+++ b/source/slang/slang-serialize.h
@@ -359,8 +359,14 @@ public:
SerialIndex addName(const Name* name);
/// Adding import symbols
- SerialIndex addImportSymbol(const UnownedStringSlice& slice) { return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice); }
- SerialIndex addImportSymbol(const String& string){ return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, string.getUnownedSlice()); }
+ SerialIndex addImportSymbol(const UnownedStringSlice& slice)
+ {
+ return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice);
+ }
+ SerialIndex addImportSymbol(const String& string)
+ {
+ return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, string.getUnownedSlice());
+ }
/// Set a the ptr associated with an index.
/// NOTE! That there cannot be a pre-existing setting.
diff --git a/tests/autodiff/custom-intrinsic-2.slang b/tests/autodiff/custom-intrinsic-2.slang
new file mode 100644
index 000000000..0a2fd9c0b
--- /dev/null
+++ b/tests/autodiff/custom-intrinsic-2.slang
@@ -0,0 +1,37 @@
+//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
+//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
+
+//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
+RWStructuredBuffer<float> outputBuffer;
+
+typedef DifferentialPair<float> dpfloat;
+
+float f(float x)
+{
+ return x*x;
+}
+
+[ForwardDerivativeOf(f)]
+DifferentialPair<float> df(DifferentialPair<float> x)
+{
+ var primal = x.p * x.p;
+ var diff = 2 * x.p * x.d;
+ return DifferentialPair<float>(primal, diff);
+}
+
+[ForwardDifferentiable]
+float g(float x)
+{
+ return f(x);
+}
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
+{
+ {
+ dpfloat dpa = dpfloat(3.0, 1.0);
+
+ outputBuffer[0] = f(dpa.p); // Expect: 9.0
+ outputBuffer[1] = __fwd_diff(g)(dpa).d; // Expect: 6.0
+ }
+}
diff --git a/tests/autodiff/custom-intrinsic-2.slang.expected.txt b/tests/autodiff/custom-intrinsic-2.slang.expected.txt
new file mode 100644
index 000000000..5483a4781
--- /dev/null
+++ b/tests/autodiff/custom-intrinsic-2.slang.expected.txt
@@ -0,0 +1,6 @@
+type: float
+9.000000
+6.000000
+0.000000
+0.000000
+0.000000
diff --git a/tests/autodiff/dstdlib-vector.slang b/tests/autodiff/dstdlib-vector.slang
index 1a1bd0dfa..ba66f0756 100644
--- a/tests/autodiff/dstdlib-vector.slang
+++ b/tests/autodiff/dstdlib-vector.slang
@@ -10,7 +10,7 @@ typedef DifferentialPair<float> dpfloat;
float f(float x)
{
float3 vx = float3(x, 2*x, 3*x);
- float3 vexpx = dstd.exp(vx);
+ float3 vexpx = exp(vx);
return vexpx.x + vexpx.y + vexpx.z;
}
diff --git a/tests/autodiff/dstdlib.slang b/tests/autodiff/dstdlib.slang
index 247200511..b96cd3c51 100644
--- a/tests/autodiff/dstdlib.slang
+++ b/tests/autodiff/dstdlib.slang
@@ -9,19 +9,19 @@ typedef DifferentialPair<float> dpfloat;
[ForwardDifferentiable]
float f(float x)
{
- return dstd.exp(x);
+ return exp(x);
}
[ForwardDifferentiable]
float g(float x)
{
- return dstd.sin(x);
+ return sin(x);
}
[ForwardDifferentiable]
float h(float x)
{
- return dstd.cos(x);
+ return cos(x);
}
[numthreads(1, 1, 1)]