diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-09 19:19:17 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-09 19:19:17 -0800 |
| commit | 004f6e30b5df3a3df2c26fe5c4a5e78c49f71166 (patch) | |
| tree | cbc942746bab043da0eb5298993d95f9665dfddf /source | |
| parent | cedd93690c63188cf98e452c9d104cf51aad6c4e (diff) | |
Add `[ForwardDerivativeOf]` attribute. (#2501)
* Add [ForwardDerivativeOf] attribute.
* Fix handling around phi nodes.
* Fixes.
* Remove IR opcode for ForwardDerivativeOfDecoration.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/core.meta.slang | 3 | ||||
| -rw-r--r-- | source/slang/diff.meta.slang | 126 | ||||
| -rw-r--r-- | source/slang/slang-ast-modifier.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.cpp | 20 | ||||
| -rw-r--r-- | source/slang/slang-ast-support-types.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.cpp | 5 | ||||
| -rw-r--r-- | source/slang/slang-ast-type.h | 5 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 95 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 9 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-check-modifier.cpp | 100 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-ir-diff-jvp.cpp | 112 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 114 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 246 | ||||
| -rw-r--r-- | source/slang/slang-mangle.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-serialize.h | 10 |
19 files changed, 520 insertions, 369 deletions
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index a37124bdc..e1eb9c776 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -2859,9 +2859,6 @@ __attributeTarget(InterfaceDecl) attribute_syntax [Specialize] : SpecializeAttribute; __attributeTarget(DeclBase) -attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; - -__attributeTarget(DeclBase) attribute_syntax [builtin] : BuiltinAttribute; __attributeTarget(DeclBase) diff --git a/source/slang/diff.meta.slang b/source/slang/diff.meta.slang index c95f8e1ac..1f6064983 100644 --- a/source/slang/diff.meta.slang +++ b/source/slang/diff.meta.slang @@ -10,6 +10,13 @@ __attributeTarget(FunctionDeclBase) attribute_syntax [ForwardDerivative(function)] : ForwardDerivativeAttribute; +__attributeTarget(FunctionDeclBase) +attribute_syntax [ForwardDerivativeOf(function)] : ForwardDerivativeOfAttribute; + +__attributeTarget(DeclBase) +attribute_syntax [DerivativeMember(memberName)] : DerivativeMemberAttribute; + + /// Pair type that serves to wrap the primal and /// differential types of an arbitrary type T. @@ -83,85 +90,46 @@ struct DifferentialPair : IDifferentiable #define VECTOR_MAP_UNARY(TYPE, COUNT, FUNC, VALUE) \ vector<TYPE,COUNT> result; for(int i = 0; i < COUNT; ++i) { result[i] = FUNC(VALUE[i]); } return result -namespace dstd +// Natural Exponent + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(exp)] +DifferentialPair<T> __d_exp(DifferentialPair<T> dpx) { - // Natural Exponent - __generic<T : __BuiltinFloatingPointType> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_exp($0)") - __target_intrinsic(cpp, "$P_exp($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp<T>)] - T exp(T x); - - __generic<T : __BuiltinFloatingPointType> - DifferentialPair<T> d_exp(DifferentialPair<T> dpx) - { - return DifferentialPair<T>( - dstd.exp(dpx.p), - T.dmul(dstd.exp(dpx.p), dpx.d)); - } - - // Sine - __generic<T : __BuiltinFloatingPointType> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_sin($0)") - __target_intrinsic(cpp, "$P_sin($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 13 _0") - [ForwardDerivative(d_sin<T>)] - T sin(T x); - - __generic<T : __BuiltinFloatingPointType> - DifferentialPair<T> d_sin(DifferentialPair<T> dpx) - { - return DifferentialPair<T>( - dstd.sin(dpx.p), - T.dmul(dstd.cos(dpx.p), dpx.d)); - } - - // Cosine - __generic<T : __BuiltinFloatingPointType> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(cuda, "$P_cos($0)") - __target_intrinsic(cpp, "$P_cos($0)") - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 14 _0") - [ForwardDerivative(d_cos<T>)] - T cos(T x); - - __generic<T : __BuiltinFloatingPointType> - DifferentialPair<T> d_cos(DifferentialPair<T> dpx) - { - return DifferentialPair<T>( - dstd.cos(dpx.p), - T.dmul(-dstd.sin(dpx.p), dpx.d)); - } - - __generic<let N : int> - __target_intrinsic(hlsl) - __target_intrinsic(glsl) - __target_intrinsic(spirv_direct, "12 resultType resultId glsl450 27 _0") - [ForwardDerivative(d_exp_vector)] - vector<float, N> exp(vector<float, N> x) - { - VECTOR_MAP_UNARY(float, N, dstd.exp, x); - } - - __generic<let N : int> - DifferentialPair<vector<float, N>> d_exp_vector(DifferentialPair<vector<float, N>> dpx) - { - vector<float, N> result; - vector<float, N>.Differential d_result; - for(int i = 0; i < N; ++i) - { - DifferentialPair<float> dpexp = dstd.d_exp(DifferentialPair<float>(dpx.p[i], dpx.d[i])); - result[i] = dpexp.p; - d_result[i] = dpexp.d; - } - - return DifferentialPair<vector<float, N>>(result, d_result); + return DifferentialPair<T>( + exp(dpx.p), + T.dmul(exp(dpx.p), dpx.d)); +} + +__generic<T:__BuiltinFloatingPointType, let N : int> +[ForwardDerivativeOf(exp)] +DifferentialPair<vector<T, N>> __d_exp_vector(DifferentialPair<vector<T, N>> dpx) +{ + vector<T, N> result; + vector<T, N>.Differential d_result; + for(int i = 0; i < N; ++i) + { + DifferentialPair<T> dpexp = __d_exp(DifferentialPair<T>(dpx.p[i], __slang_noop_cast<T.Differential>(dpx.d[i]))); + result[i] = dpexp.p; + d_result[i] = __slang_noop_cast<T>(dpexp.d); } + return DifferentialPair<vector<T, N>>(result, d_result); +} -}; +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(sin)] +DifferentialPair<T> d_sin(DifferentialPair<T> dpx) +{ + return DifferentialPair<T>( + sin(dpx.p), + T.dmul(cos(dpx.p), dpx.d)); +} + +__generic<T : __BuiltinFloatingPointType> +[ForwardDerivativeOf(cos)] +DifferentialPair<T> d_cos(DifferentialPair<T> dpx) +{ + return DifferentialPair<T>( + cos(dpx.p), + T.dmul(-sin(dpx.p), dpx.d)); +} diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 57dfbac9e..d6a961328 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1031,7 +1031,18 @@ class ForwardDerivativeAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(ForwardDerivativeAttribute) - DeclRefExpr* funcDeclRef; + Expr* funcExpr; +}; + + /// The `[ForwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom + /// derivative implementation for `primalFunction`. +class ForwardDerivativeOfAttribute : public Attribute +{ + SLANG_AST_CLASS(ForwardDerivativeOfAttribute) + + Expr* funcExpr; + + Expr* backDeclRef; // DeclRef to this derivative function when initiated from primalFunction. }; /// Indicates that the modified declaration is one of the "magic" declarations diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 7133f2a65..1f30e0238 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -14,4 +14,24 @@ QualType::QualType(Type* type) } } +void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove) +{ + Modifier* prev = nullptr; + for (auto modifier = syntax->modifiers.first; modifier; modifier = modifier->next) + { + if (modifier == toRemove) + { + if (prev) + { + prev->next = modifier->next; + } + else + { + syntax->modifiers.first = syntax->modifiers.first->next; + } + break; + } + prev = modifier; + } +} } diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index d4a781846..89ae0da7d 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -474,6 +474,10 @@ namespace Slang ModifiableSyntaxNode* syntax, Modifier* modifier); + void removeModifier( + ModifiableSyntaxNode* syntax, + Modifier* modifier); + struct QualType { SLANG_VALUE_CLASS(QualType) diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index a869c95a7..ba033c3ad 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -473,6 +473,11 @@ Type* NamespaceType::_createCanonicalTypeOverride() return this; } +Type* DifferentialPairType::getPrimalType() +{ + return as<Type>(findInnerMostGenericSubstitution(declRef.substitutions)->getArgs()[0]); +} + // !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! PtrTypeBase !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!! diff --git a/source/slang/slang-ast-type.h b/source/slang/slang-ast-type.h index c7ce21cb0..d9829c4ca 100644 --- a/source/slang/slang-ast-type.h +++ b/source/slang/slang-ast-type.h @@ -463,10 +463,7 @@ protected: class DifferentialPairType : public ArithmeticExpressionType { SLANG_AST_CLASS(DifferentialPairType) - - // The type of vector elements. - // As an invariant, this should be a basic type or an alias. - Type* baseType = nullptr; + Type* getPrimalType(); }; class DifferentiableType : public BuiltinType diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 333e9d973..b33c33e7a 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -256,6 +256,11 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context); + + void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl); + + void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr); + }; /// Should the given `decl` nested in `parentDecl` be treated as a static rather than instance declaration? @@ -4582,11 +4587,101 @@ namespace Slang } } + void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl) + { + auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>(); + if (!attr) + return; + + List<Expr*> imaginaryArguments; + for (auto param : funcDecl->getParameters()) + { + auto arg = m_astBuilder->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = attr->loc; + if (auto pairType = as<DifferentialPairType>(param->getType())) + { + arg->type.type = pairType->getPrimalType(); + } + imaginaryArguments.add(arg); + } + auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto resolved = ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + if (auto existingModifier = calleeDeclRef->declRef.getDecl()->findModifier<ForwardDerivativeAttribute>()) + { + // The primal function already has a `[ForwardDerivative]` attribute, this is invalid. + getSink()->diagnose(attr, Diagnostics::declAlreadyHasAttribute, calleeDeclRef->declRef, "[ForwardDerivative]"); + getSink()->diagnose(existingModifier->loc, Diagnostics::seeDeclarationOf, calleeDeclRef->declRef.getDecl()); + } + attr->funcExpr = calleeDeclRef; + auto fwdDerivativeAttr = m_astBuilder->create<ForwardDerivativeAttribute>(); + fwdDerivativeAttr->loc = attr->loc; + auto outterGeneric = GetOuterGeneric(funcDecl); + auto declRef = + DeclRef<Decl>((outterGeneric ? (Decl*)outterGeneric : funcDecl), nullptr); + auto declRefExpr = ConstructDeclRefExpr(declRef, nullptr, attr->loc, nullptr); + declRefExpr->type.type = nullptr; + fwdDerivativeAttr->args.add(declRefExpr); + fwdDerivativeAttr->funcExpr = declRefExpr; + checkDerivativeAttribute(as<FunctionDeclBase>(calleeDeclRef->declRef.getDecl()), fwdDerivativeAttr); + attr->backDeclRef = fwdDerivativeAttr->funcExpr; + fwdDerivativeAttr->funcExpr = nullptr; + return; + } + } + getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + + void SemanticsDeclBodyVisitor::checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr) + { + if (!attr->funcExpr) + return; + if (attr->funcExpr->type.type) + return; + + List<Expr*> imaginaryArguments; + for (auto param : funcDecl->getParameters()) + { + auto arg = m_astBuilder->create<VarExpr>(); + arg->declRef.decl = param; + arg->type.isLeftValue = param->findModifier<OutModifier>() ? true : false; + arg->type.type = param->getType(); + arg->loc = attr->loc; + if (auto pairType = getDifferentialPairType(param->getType())) + { + arg->type.type = pairType; + } + imaginaryArguments.add(arg); + } + auto invokeExpr = constructUncheckedInvokeExpr(attr->funcExpr, imaginaryArguments); + auto resolved = ResolveInvoke(invokeExpr); + if (auto resolvedInvoke = as<InvokeExpr>(resolved)) + { + if (auto calleeDeclRef = as<DeclRefExpr>(resolvedInvoke->functionExpr)) + { + attr->funcExpr = calleeDeclRef; + return; + } + } + getSink()->diagnose(attr, Diagnostics::invalidCustomDerivative); + } + void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); _maybeRegisterDifferentialBottomTypeConformance(newContext); + // Run checking on attributes that can't be fully checked in header checking stage. + checkDerivativeOfAttribute(decl); + if (auto derivativeAttr = decl->findModifier<ForwardDerivativeAttribute>()) + checkDerivativeAttribute(decl, derivativeAttr); + if (auto body = decl->body) { checkStmt(decl->body, newContext); diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 09dd9eea1..30db9ecfa 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -393,6 +393,15 @@ namespace Slang return derefExpr; } + InvokeExpr* SemanticsVisitor::constructUncheckedInvokeExpr(Expr* callee, const List<Expr*>& arguments) + { + auto result = m_astBuilder->create<InvokeExpr>(); + result->loc = callee->loc; + result->functionExpr = callee; + result->arguments.addRange(arguments); + return result; + } + Expr* SemanticsVisitor::maybeUseSynthesizedDeclForLookupResult(LookupResultItem const& item, Expr* originalExpr) { // If the only result from lookup is an entry in an interface decl, it could be that diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 76918ebbe..70b120518 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -542,6 +542,8 @@ namespace Slang Expr* base, SourceLoc loc); + InvokeExpr* constructUncheckedInvokeExpr(Expr* callee, const List<Expr*>& arguments); + Expr* maybeUseSynthesizedDeclForLookupResult( LookupResultItem const& item, Expr* orignalExpr); diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index d8b05198c..b8ac21e2d 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -617,92 +617,30 @@ namespace Slang getSink()->diagnose(diffExpr, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); return false; } - - // Either diffExpr has a function type, or it is a reference to a generic. - if (!as<FuncType>(diffExpr->type) && - !(as<DeclRefExpr>(diffExpr) && - as<DeclRefExpr>(diffExpr)->declRef.as<GenericDecl>().getDecl() != nullptr)) - { - return false; - } - - auto diffDeclRef = as<DeclRefExpr>(diffExpr)->declRef; - - UCount genericLevels = 0; - // If we've grabbed the outer generic for some reason, - // recursively construct GenericAppExpr<...>(generic) - // and check that to get a specialized func. - // - while (diffDeclRef.as<GenericDecl>().getDecl() != nullptr) - { - // Forward to the inner decl - diffDeclRef = makeDeclRef(diffDeclRef.as<GenericDecl>().getDecl()->inner); - - // Increment counter. - genericLevels += 1; - } - - auto targetGeneric = as<GenericDecl>(as<Decl>(attrTarget)->parentDecl); - auto diffGeneric = as<GenericDecl>(diffDeclRef.getDecl()->parentDecl); - Expr* currentDiffExpr = diffExpr; - - // Go back through each level, and use generic declarations in the - // target's generic scope as arguments for the diff function's generic. + // We store the partially checked funcExpr in the attribute, and + // rely on `ResolveInvoke` to resolve it to the actual function decl. + // The call to `ResolveInvoke` is deferred until we are checking the + // body of the function. // - for (UIndex ii = 0; ii < genericLevels; ii++) - { - // Nest our expression inside a GenericAppExpr - auto genericAppExpr = getASTBuilder()->create<GenericAppExpr>(); - genericAppExpr->functionExpr = currentDiffExpr; - - // Construct references to the generic args in the current scope. - // TODO: Probably an easier way to do this. - for (auto member : targetGeneric->members) - { - if (auto typeParamDecl = as<GenericTypeParamDecl>(member)) - { - genericAppExpr->arguments.add( - ConstructDeclRefExpr(makeDeclRef(typeParamDecl), nullptr, typeParamDecl->loc, nullptr)); - } - else if (auto valueParamDecl = as<GenericValueParamDecl>(member)) - { - genericAppExpr->arguments.add( - ConstructDeclRefExpr(makeDeclRef(valueParamDecl), nullptr, valueParamDecl->loc, nullptr)); - } - } - - // Set our generic-app-expr as the new expr. - currentDiffExpr = genericAppExpr; - - // Peel the generic layer. - diffGeneric = as<GenericDecl>(diffGeneric->parentDecl); - targetGeneric = as<GenericDecl>(targetGeneric->parentDecl); - } - - if ((diffGeneric == nullptr && targetGeneric != nullptr) || - (targetGeneric == nullptr && diffGeneric != nullptr)) - { - //getSink()->diagnose(diffDeclRef, Slang::Diagnostics::customDerivativeGenericSignatureMismatch, diffDeclRef, attrTarget); - SLANG_UNEXPECTED(""); - } - - // If we had to change currentDiffExpr, then re-check the expr. - if (!currentDiffExpr->type) - { - currentDiffExpr = CheckTerm(currentDiffExpr); - } + // Set type to null to indicate that this needs expr needs to be further resolved. + diffExpr->type.type = nullptr; + forwardDerivativeAttr->funcExpr = diffExpr; + } + else if (auto forwardDerivativeOfAttr = as<ForwardDerivativeOfAttribute>(attr)) + { + SLANG_ASSERT(attr->args.getCount() == 1); + SLANG_ASSERT(as<Decl>(attrTarget)); // Ensure that the argument is a reference to a function definition or declaration. - auto currentDiffDeclRefExpr = as<DeclRefExpr>(currentDiffExpr); - auto currentDiffDeclRef = currentDiffDeclRefExpr->declRef; - - if (!as<FuncType>(GetTypeForDeclRef(currentDiffDeclRef, currentDiffDeclRef.getLoc()))) + auto primalFunc = CheckTerm(attr->args[0]); + if (primalFunc->type == getASTBuilder()->getErrorType()) { - getSink()->diagnose(currentDiffDeclRef, Slang::Diagnostics::customDerivativeNotAFunction, currentDiffDeclRef); + // Could not resolve the term. + getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); + return false; } - - // TODO: Can possibly just store a DeclRef (no need for DeclRefExpr) - forwardDerivativeAttr->funcDeclRef = as<DeclRefExpr>(ConstructDeclRefExpr(currentDiffDeclRef, nullptr, currentDiffDeclRefExpr->loc, diffExpr)); + + forwardDerivativeOfAttr->funcExpr = primalFunc; } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) { diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index ffee0622c..5263ac39b 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -341,10 +341,9 @@ DIAGNOSTIC(31141, Error, definitionOfExternDeclMismatchesOriginalDefinition, "`e DIAGNOSTIC(31142, Error, ambiguousOriginalDefintionOfExternDecl, "`extern` decl '$0' has ambiguous original definitions.") DIAGNOSTIC(31143, Error, missingOriginalDefintionOfExternDecl, "no original definition found for `extern` decl '$0'.") -DIAGNOSTIC(31144, Error, customDerivativeNotAFunction, "$0, used as a custom derivative, is not a function") -DIAGNOSTIC(31145, Error, customDerivativeGenericSignatureMismatch, "cannot use $0 as custom derivative for $1. generic signature does not match") -DIAGNOSTIC(31146, Error, customDerivativeSignatureMismatch, "cannot use $0 as custom derivative for $1. signature does not match") -DIAGNOSTIC(31146, Error, invalidCustomDerivative, "unable to resolve custom differential for $0.") +DIAGNOSTIC(31145, Error, invalidCustomDerivative, "invalid custom derivative attribute.") +DIAGNOSTIC(31146, Error, declAlreadyHasAttribute, "'$0' already has attribute '$1'.") + // Enums DIAGNOSTIC(32000, Error, invalidEnumTagType, "invalid tag type for 'enum': '$0'") diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 574db2036..4c7a132d0 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -847,37 +847,51 @@ struct JVPTranscriber cloneInst(&cloneEnv, builder, origParam), nullptr); } - - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + + // Is this param a phi node or a function parameter? + auto func = as<IRGlobalValueWithCode>(origParam->getParent()->getParent()); + bool isFuncParam = (func && origParam->getParent() == func->getFirstBlock()); + if (isFuncParam) { - IRInst* diffPairParam = builder->emitParam(diffPairType); + if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)primalDataType)) + { + IRInst* diffPairParam = builder->emitParam(diffPairType); - auto diffPairVarName = makeDiffPairName(origParam); - if (diffPairVarName.getLength() > 0) - builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); + auto diffPairVarName = makeDiffPairName(origParam); + if (diffPairVarName.getLength() > 0) + builder->addNameHintDecoration(diffPairParam, diffPairVarName.getUnownedSlice()); - SLANG_ASSERT(diffPairParam); + SLANG_ASSERT(diffPairParam); - if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + if (auto pairType = as<IRDifferentialPairType>(diffPairParam->getDataType())) + { + return InstPair( + builder->emitDifferentialPairGetPrimal(diffPairParam), + builder->emitDifferentialPairGetDifferential( + (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), + diffPairParam)); + } + // If this is an `in/inout DifferentialPair<>` parameter, we can't produce + // its primal and diff parts right now because they would represent a reference + // to a pair field, which doesn't make sense since pair types are considered mutable. + // We encode the result as if the param is non-differentiable, and handle it + // with special care at load/store. + return InstPair(diffPairParam, nullptr); + } + return InstPair( + cloneInst(&cloneEnv, builder, origParam), + nullptr); + } + else + { + auto primal = cloneInst(&cloneEnv, builder, origParam); + IRInst* diff = nullptr; + if (IRType* diffType = differentiateType(builder, (IRType*)primalDataType)) { - return InstPair( - builder->emitDifferentialPairGetPrimal(diffPairParam), - builder->emitDifferentialPairGetDifferential( - (IRType*)pairBuilder->getDiffTypeFromPairType(builder, pairType), - diffPairParam)); + diff = builder->emitParam(diffType); } - // If this is an `in/inout DifferentialPair<>` parameter, we can't produce - // its primal and diff parts right now because they would represent a reference - // to a pair field, which doesn't make sense since pair types are considered mutable. - // We encode the result as if the param is non-differentiable, and handle it - // with special care at load/store. - return InstPair(diffPairParam, nullptr); + return InstPair(primal, diff); } - - - return InstPair( - cloneInst(&cloneEnv, builder, origParam), - nullptr); } // Returns "d<var-name>" to use as a name hint for variables and parameters. @@ -1313,42 +1327,49 @@ struct JVPTranscriber switch(origInst->getOp()) { case kIROp_unconditionalBranch: + case kIROp_loop: auto origBranch = as<IRUnconditionalBranch>(origInst); // Grab the differentials for any phi nodes. - List<IRInst*> pairArgs; + List<IRInst*> newArgs; for (UIndex ii = 0; ii < origBranch->getArgCount(); ii++) { auto origArg = origBranch->getArg(ii); + auto primalArg = lookupPrimalInst(origArg); + newArgs.add(primalArg); - IRInst* pairArg = nullptr; - if (auto diffPairType = tryGetDiffPairType(builder, (IRType*)origArg->getDataType())) + if (differentiateType(builder, primalArg->getDataType())) { auto diffArg = lookupDiffInst(origArg, nullptr); - if (!diffArg) - { - diffArg = getDifferentialZeroOfType(builder, (IRType*)origArg->getDataType()); - } - - pairArg = builder->emitMakeDifferentialPair( - diffPairType, - lookupPrimalInst(origArg), - diffArg); - } - else - { - pairArg = lookupPrimalInst(origArg); + if (diffArg) + newArgs.add(diffArg); } - pairArgs.add(pairArg); } IRInst* diffBranch = nullptr; if (auto diffBlock = findOrTranscribeDiffInst(builder, origBranch->getTargetBlock())) { - diffBranch = builder->emitBranch( - as<IRBlock>(diffBlock), - pairArgs.getCount(), - pairArgs.getBuffer()); + if (auto origLoop = as<IRLoop>(origInst)) + { + auto breakBlock = findOrTranscribeDiffInst(builder, origLoop->getBreakBlock()); + auto continueBlock = findOrTranscribeDiffInst(builder, origLoop->getContinueBlock()); + List<IRInst*> operands; + operands.add(breakBlock); + operands.add(continueBlock); + operands.addRange(newArgs); + diffBranch = builder->emitIntrinsicInst( + nullptr, + kIROp_loop, + operands.getCount(), + operands.getBuffer()); + } + else + { + diffBranch = builder->emitBranch( + as<IRBlock>(diffBlock), + newArgs.getCount(), + newArgs.getBuffer()); + } } // For now, every block in the original fn must have a corresponding @@ -2517,5 +2538,4 @@ void stripAutoDiffDecorations(IRModule* module) stripAutoDiffDecorationsFromChildren(module->getModuleInst()); } - } diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index 431446f01..cb4854d7d 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -705,7 +705,7 @@ INST(HighLevelDeclDecoration, highLevelDecl, 1, 0) /// Used by the auto-diff pass to hold a reference to the /// generated derivative function. - INST(ForwardDerivativeDecoration, jvpFnReference, 1, 0) + INST(ForwardDerivativeDecoration, fwdDerivative, 1, 0) /// Used by the auto-diff pass to hold a reference to a /// differential member of a type in its associated differential type. diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 1d1e2ae69..5587a7c68 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -3207,9 +3207,9 @@ public: addDecoration(value, kIROp_ForwardDifferentiableDecoration); } - void addForwardDerivativeDecoration(IRInst* value, IRInst* jvpFn) + void addForwardDerivativeDecoration(IRInst* value, IRInst* fwdFunc) { - addDecoration(value, kIROp_ForwardDerivativeDecoration, jvpFn); + addDecoration(value, kIROp_ForwardDerivativeDecoration, fwdFunc); } void addCOMWitnessDecoration(IRInst* value, IRInst* witnessTable) diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index ad4f691f1..cf0293f0d 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -6,7 +6,7 @@ #include "slang-ir-insts.h" #include "slang-mangle.h" #include "slang-ir-string-hash.h" - +#include "slang-ir-diff-jvp.h" #include "slang-module-library.h" #include "../compiler-core/slang-artifact.h" @@ -412,7 +412,43 @@ IRGlobalVar* cloneGlobalVarImpl( /// For a given decoration opcode, only one such decoration will ever be copied, and nothing /// will be copied if the instruction already has a matching decoration (that was cloned /// from the "best" definition). - /// + /// +static void cloneExtraDecorationsFromInst( + IRSpecContextBase* context, + IRBuilder* builder, + IRInst* clonedInst, + IRInst* originalInst) +{ + for (auto decoration : originalInst->getDecorations()) + { + switch (decoration->getOp()) + { + default: + break; + + case kIROp_HLSLExportDecoration: + case kIROp_BindExistentialSlotsDecoration: + case kIROp_LayoutDecoration: + case kIROp_PublicDecoration: + case kIROp_SequentialIDDecoration: + case kIROp_ForwardDerivativeDecoration: + if (!clonedInst->findDecorationImpl(decoration->getOp())) + { + cloneInst(context, builder, decoration); + } + break; + } + } + + // We will also copy over source location information from the alternative + // values, in case any of them has it available. + // + if (originalInst->sourceLoc.isValid() && !clonedInst->sourceLoc.isValid()) + { + clonedInst->sourceLoc = originalInst->sourceLoc; + } +} + static void cloneExtraDecorations( IRSpecContextBase* context, IRInst* clonedInst, @@ -435,34 +471,7 @@ static void cloneExtraDecorations( for(auto sym = originalValues.sym; sym; sym = sym->nextWithSameName) { - for(auto decoration : sym->irGlobalValue->getDecorations()) - { - switch(decoration->getOp()) - { - default: - break; - - case kIROp_HLSLExportDecoration: - case kIROp_BindExistentialSlotsDecoration: - case kIROp_LayoutDecoration: - case kIROp_PublicDecoration: - case kIROp_SequentialIDDecoration: - case kIROp_ForwardDerivativeDecoration: - if(!clonedInst->findDecorationImpl(decoration->getOp())) - { - cloneInst(context, builder, decoration); - } - break; - } - } - - // We will also copy over source location information from the alternative - // values, in case any of them has it available. - // - if(sym->irGlobalValue->sourceLoc.isValid() && !clonedInst->sourceLoc.isValid()) - { - clonedInst->sourceLoc = sym->irGlobalValue->sourceLoc; - } + cloneExtraDecorationsFromInst(context, builder, clonedInst, sym->irGlobalValue); } } @@ -547,6 +556,43 @@ IRGeneric* cloneGenericImpl( originalVal, originalValues); + // We want to clone extra decorations on the + // return value from other symbols as well. + auto clonedInnerVal = findGenericReturnVal(clonedVal); + for (auto originalSym = originalValues.sym; originalSym; + originalSym = originalSym->nextWithSameName.get()) + { + auto originalGeneric = as<IRGeneric>(originalSym->irGlobalValue); + if (!originalGeneric) + continue; + auto originalInnerVal = findGenericReturnVal(originalGeneric); + + // Register all generic parameters before cloning the decorations. + auto clonedParam = clonedVal->getFirstParam(); + auto originalParam = originalGeneric->getFirstParam(); + + ShortList<KeyValuePair<IRInst*, IRInst*>> paramMapping; + for (; clonedParam && originalParam; (clonedParam = as<IRParam>(clonedParam->next)), (originalParam = as<IRParam>(originalParam->next))) + { + paramMapping.add(KeyValuePair<IRInst*, IRInst*>(clonedParam, originalParam)); + } + // Generic parameter list does not match, bail. + if (clonedParam || originalParam) + continue; + for (auto kv : paramMapping) + { + registerClonedValue(context, kv.Key, kv.Value); + } + + IRBuilder builderStorage = *builder; + IRBuilder* decorBuilder = &builderStorage; + decorBuilder->setInsertInto(clonedInnerVal); + if (auto firstChild = clonedInnerVal->getFirstChild()) + { + decorBuilder->setInsertBefore(firstChild); + } + cloneExtraDecorationsFromInst(context, decorBuilder, clonedInnerVal, originalInnerVal); + } return clonedVal; } @@ -694,7 +740,6 @@ void cloneGlobalValueWithCodeCommon( cb = cb->getNextBlock(); } } - } void checkIRDuplicate(IRInst* inst, IRInst* moduleInst, UnownedStringSlice const& mangledName) @@ -1405,6 +1450,13 @@ LinkedIR linkIR( // List<IRModule*> irModules; + + // Link stdlib modules. + auto builtinLinkage = static_cast<Session*>(linkage->getGlobalSession())->getBuiltinLinkage(); + for (auto& m : builtinLinkage->mapNameToLoadedModules) + irModules.add(m.Value->getIRModule()); + + // Link modules in the program. program->enumerateIRModules([&](IRModule* irModule) { irModules.add(irModule); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 12a9f73e6..5930875f1 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -547,6 +547,7 @@ bool isImportedDecl(IRGenContext* context, Decl* decl) if (!moduleDecl) return false; +#if 0 // HACK: don't treat standard library code as // being imported for right now, just because // we don't load its IR in the same way as @@ -557,6 +558,7 @@ bool isImportedDecl(IRGenContext* context, Decl* decl) // in via the normal means. if (isFromStdLib(decl)) return false; +#endif if (moduleDecl != context->getMainModuleDecl()) return true; @@ -7782,22 +7784,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return type->getOp() == kIROp_ClassType; } - LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) + LoweredValInfo lowerFuncDeclInContext(IRGenContext* subContext, IRBuilder* subBuilder, FunctionDeclBase* decl) { - // We are going to use a nested builder, because we will - // change the parent node that things get nested into. - // - NestedContext nestedContextFunc(this); - auto subBuilder = nestedContextFunc.getBuilder(); - auto subContext = nestedContextFunc.getContext(); auto outerGeneric = emitOuterGenerics(subContext, decl, decl); // need to create an IR function here IRFunc* irFunc = subBuilder->createFunc(); - addNameHint(context, irFunc, decl); - addLinkageDecoration(context, irFunc, decl); + addNameHint(subContext, irFunc, decl); + addLinkageDecoration(subContext, irFunc, decl); if (decl->findModifier<ForwardDifferentiableAttribute>()) { @@ -7868,7 +7864,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subBuilder->setInsertInto(entryBlock); UInt paramTypeIndex = 0; - for( auto paramInfo : parameterLists.params ) + for (auto paramInfo : parameterLists.params) { auto irParamType = paramTypes[paramTypeIndex++]; @@ -7876,91 +7872,91 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IRParam* irParam = nullptr; - switch( paramInfo.direction ) + switch (paramInfo.direction) { default: - { - // The parameter is being used for input/output purposes, - // so it will lower to an actual parameter with a pointer type. - // - // TODO: Is this the best representation we can use? + { + // The parameter is being used for input/output purposes, + // so it will lower to an actual parameter with a pointer type. + // + // TODO: Is this the best representation we can use? - irParam = subBuilder->emitParam(irParamType); - if(auto paramDecl = paramInfo.decl) - { - addVarDecorations(context, irParam, paramDecl); - subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); - } - addParamNameHint(irParam, paramInfo); + irParam = subBuilder->emitParam(irParamType); + if (auto paramDecl = paramInfo.decl) + { + addVarDecorations(context, irParam, paramDecl); + subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); + } + addParamNameHint(irParam, paramInfo); - paramVal = LoweredValInfo::ptr(irParam); + paramVal = LoweredValInfo::ptr(irParam); - // TODO: We might want to copy the pointed-to value into - // a temporary at the start of the function, and then copy - // back out at the end, so that we don't have to worry - // about things like aliasing in the function body. - // - // For now we will just use the storage that was passed - // in by the caller, knowing that our current lowering - // at call sites will guarantee a fresh/unique location. - } - break; + // TODO: We might want to copy the pointed-to value into + // a temporary at the start of the function, and then copy + // back out at the end, so that we don't have to worry + // about things like aliasing in the function body. + // + // For now we will just use the storage that was passed + // in by the caller, knowing that our current lowering + // at call sites will guarantee a fresh/unique location. + } + break; case kParameterDirection_In: + { + // Simple case of a by-value input parameter. + // + // We start by declaring an IR parameter of the same type. + // + auto paramDecl = paramInfo.decl; + irParam = subBuilder->emitParam(irParamType); + if (paramDecl) { - // Simple case of a by-value input parameter. - // - // We start by declaring an IR parameter of the same type. - // - auto paramDecl = paramInfo.decl; - irParam = subBuilder->emitParam(irParamType); - if( paramDecl ) - { - addVarDecorations(context, irParam, paramDecl); - subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); - } - addParamNameHint(irParam, paramInfo); - paramVal = LoweredValInfo::simple(irParam); - // - // HLSL allows a function parameter to be used as a local - // variable in the function body (just like C/C++), so - // we need to support that case as well. + addVarDecorations(context, irParam, paramDecl); + subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); + } + addParamNameHint(irParam, paramInfo); + paramVal = LoweredValInfo::simple(irParam); + // + // HLSL allows a function parameter to be used as a local + // variable in the function body (just like C/C++), so + // we need to support that case as well. + // + // However, if we notice that the parameter was marked + // `const`, then we can skip this step. + // + // TODO: we should consider having all parameter be implicitly + // immutable except in a specific "compatibility mode." + // + if (paramDecl && paramDecl->findModifier<ConstModifier>()) + { + // This parameter was declared to be immutable, + // so there should be no assignment to it in the + // function body, and we don't need a temporary. + } + else + { + // The parameter migth get used as a temporary in + // the function body. We will allocate a mutable + // local variable for is value, and then assign + // from the parameter to the local at the start + // of the function. // - // However, if we notice that the parameter was marked - // `const`, then we can skip this step. + auto irLocal = subBuilder->emitVar(irParamType); + auto localVal = LoweredValInfo::ptr(irLocal); + assign(subContext, localVal, paramVal); // - // TODO: we should consider having all parameter be implicitly - // immutable except in a specific "compatibility mode." + // When code later in the body of the function refers + // to the parameter declaration, it will actually refer + // to the value stored in the local variable. // - if(paramDecl && paramDecl->findModifier<ConstModifier>()) - { - // This parameter was declared to be immutable, - // so there should be no assignment to it in the - // function body, and we don't need a temporary. - } - else - { - // The parameter migth get used as a temporary in - // the function body. We will allocate a mutable - // local variable for is value, and then assign - // from the parameter to the local at the start - // of the function. - // - auto irLocal = subBuilder->emitVar(irParamType); - auto localVal = LoweredValInfo::ptr(irLocal); - assign(subContext, localVal, paramVal); - // - // When code later in the body of the function refers - // to the parameter declaration, it will actually refer - // to the value stored in the local variable. - // - paramVal = localVal; - } + paramVal = localVal; } - break; + } + break; } - if( auto paramDecl = paramInfo.decl ) + if (auto paramDecl = paramInfo.decl) { setValue(subContext, paramDecl, paramVal); } @@ -8008,7 +8004,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // a local variable to represent this value. // auto constructorDecl = as<ConstructorDecl>(decl); - if(constructorDecl) + if (constructorDecl) { auto thisVar = subContext->irBuilder->emitVar(irResultType); subContext->thisVal = LoweredValInfo::ptr(thisVar); @@ -8031,7 +8027,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // if (!subContext->irBuilder->getBlock()->getTerminator()) { - if(constructorDecl) + if (constructorDecl) { // A constructor declaration should return the // value of the `this` variable that was set @@ -8044,7 +8040,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> subContext->irBuilder->emitReturn( getSimpleVal(subContext, subContext->thisVal)); } - else if(as<IRVoidType>(irResultType)) + else if (as<IRVoidType>(irResultType)) { // `void`-returning function can get an implicit // return on exit of the body statement. @@ -8075,7 +8071,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // If this declaration was marked as being an intrinsic for a particular // target, then we should reflect that here. - for( auto targetMod : decl->getModifiersOfType<SpecializedForTargetModifier>() ) + for (auto targetMod : decl->getModifiersOfType<SpecializedForTargetModifier>()) { // `targetMod` indicates that this particular declaration represents // a specialized definition of the particular function for the given @@ -8099,11 +8095,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // TODO: We should wrap this an `SpecializedForTargetModifier` together into a single // case for enumerating the "capabilities" that a declaration requires. // - for(auto extensionMod : decl->getModifiersOfType<RequiredGLSLExtensionModifier>()) + for (auto extensionMod : decl->getModifiersOfType<RequiredGLSLExtensionModifier>()) { getBuilder()->addRequireGLSLExtensionDecoration(irFunc, extensionMod->extensionNameToken.getContent()); } - for(auto versionMod : decl->getModifiersOfType<RequiredGLSLVersionModifier>()) + for (auto versionMod : decl->getModifiersOfType<RequiredGLSLVersionModifier>()) { getBuilder()->addRequireGLSLVersionDecoration(irFunc, Int(getIntegerLiteralValue(versionMod->versionNumberToken))); } @@ -8116,12 +8112,12 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addRequireCUDASMVersionDecoration(irFunc, versionMod->version); } - if(decl->findModifier<RequiresNVAPIAttribute>()) + if (decl->findModifier<RequiresNVAPIAttribute>()) { getBuilder()->addSimpleDecoration<IRRequiresNVAPIDecoration>(irFunc); } - if(decl->findModifier<NoInlineAttribute>()) + if (decl->findModifier<NoInlineAttribute>()) { getBuilder()->addSimpleDecoration<IRNoInlineDecoration>(irFunc); } @@ -8132,13 +8128,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_InstanceDecoration, intLit); } - if(auto attr = decl->findModifier<MaxVertexCountAttribute>()) + if (auto attr = decl->findModifier<MaxVertexCountAttribute>()) { IRIntLit* intLit = _getIntLitFromAttribute(getBuilder(), attr); getBuilder()->addDecoration(irFunc, kIROp_MaxVertexCountDecoration, intLit); } - if(auto attr = decl->findModifier<NumThreadsAttribute>()) + if (auto attr = decl->findModifier<NumThreadsAttribute>()) { auto builder = getBuilder(); IRType* intType = builder->getIntType(); @@ -8149,10 +8145,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> builder->getIntValue(intType, attr->z) }; - builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); + builder->addDecoration(irFunc, kIROp_NumThreadsDecoration, operands, 3); } - if(decl->findModifier<ReadNoneAttribute>()) + if (decl->findModifier<ReadNoneAttribute>()) { getBuilder()->addSimpleDecoration<IRReadNoneDecoration>(irFunc); } @@ -8192,7 +8188,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> getBuilder()->addDecoration(irFunc, kIROp_SPIRVOpDecoration, intLit); } - if(decl->findModifier<UnsafeForceInlineEarlyAttribute>()) + if (decl->findModifier<UnsafeForceInlineEarlyAttribute>()) { getBuilder()->addDecoration(irFunc, kIROp_UnsafeForceInlineEarlyDecoration); } @@ -8207,23 +8203,54 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (auto attr = decl->findModifier<ForwardDerivativeAttribute>()) { - // TODO(Sai): HACK.. we need to emit a decl-ref to handle this modifier correctly. - // If we don't move the cursor to the parent, we sometimes emit supporting - // insts into the function body, which shouldn't happen. - // - subContext->irBuilder->setInsertInto(irFunc->getParent()); - - auto diffFuncType = getFuncType(subContext->astBuilder, attr->funcDeclRef->declRef.as<CallableDecl>()); - auto irDiffFuncType = lowerType(subContext, diffFuncType); + // We need to lower the decl ref to the custom derivative function to IR. + // The IR insts correspond to the decl ref is not part of the function we + // are processing. If we emit it directly to within the function, it could + // mess up the assumption on the form of the IR (e.g. having non decoration insts + // appearing in the middle of decoration insts). so we emit the decl ref to the + // function's parent for now. + + subContext->irBuilder->setInsertInto(irFunc->getParent()); - auto loweredVal = emitDeclRef(subContext, attr->funcDeclRef->declRef, irDiffFuncType); + auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); IRInst* jvpFunc = loweredVal.val; getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, jvpFunc); // Reset cursor. - subContext->irBuilder->setInsertInto(irFunc); + subContext->irBuilder->setInsertInto(irFunc); + } + + if (auto attr = decl->findModifier<ForwardDerivativeOfAttribute>()) + { + if (auto originalDeclRefExpr = as<DeclRefExpr>(attr->funcExpr)) + { + NestedContext originalContextFunc(this); + auto originalSubBuilder = originalContextFunc.getBuilder(); + auto originalSubContext = originalContextFunc.getContext(); + + auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl()); + SLANG_RELEASE_ASSERT(originalFuncDecl); + + auto originalFuncVal = lowerFuncDeclInContext(originalSubContext, originalSubBuilder, originalFuncDecl).val; + if (auto originalFuncGeneric = as<IRGeneric>(originalFuncVal)) + { + originalFuncVal = findGenericReturnVal(originalFuncGeneric); + } + originalSubBuilder->setInsertBefore(originalFuncVal); + auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); + originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); + } + + subContext->irBuilder->setInsertInto(irFunc->getParent()); + auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); + + SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); + IRInst* originalFunc = loweredVal.val; + getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, originalFunc); + + subContext->irBuilder->setInsertInto(irFunc); } // For convenience, ensure that any additional global @@ -8239,6 +8266,17 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return LoweredValInfo::simple(finalVal); } + LoweredValInfo lowerFuncDecl(FunctionDeclBase* decl) + { + // We are going to use a nested builder, because we will + // change the parent node that things get nested into. + // + NestedContext nestedContextFunc(this); + auto subBuilder = nestedContextFunc.getBuilder(); + auto subContext = nestedContextFunc.getContext(); + return lowerFuncDeclInContext(subContext, subBuilder, decl); + } + LoweredValInfo visitGenericDecl(GenericDecl * genDecl) { // TODO: Should this just always visit/lower the inner decl? diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index ab1c1ec4a..f88549e41 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -301,17 +301,7 @@ namespace Slang auto parentGenericDeclRef = parentDeclRef.as<GenericDecl>(); if( parentDeclRef ) { - // In certain cases we want to skip emitting the parent - if(parentGenericDeclRef && (parentGenericDeclRef.getDecl()->inner != declRef.getDecl())) - { - } - else if(parentDeclRef.as<FunctionDeclBase>()) - { - } - else - { - emitQualifiedName(context, parentDeclRef); - } + emitQualifiedName(context, parentDeclRef); } // A generic declaration is kind of a pseudo-declaration diff --git a/source/slang/slang-serialize.h b/source/slang/slang-serialize.h index e08d26dd5..581ce2e5f 100644 --- a/source/slang/slang-serialize.h +++ b/source/slang/slang-serialize.h @@ -359,8 +359,14 @@ public: SerialIndex addName(const Name* name); /// Adding import symbols - SerialIndex addImportSymbol(const UnownedStringSlice& slice) { return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice); } - SerialIndex addImportSymbol(const String& string){ return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, string.getUnownedSlice()); } + SerialIndex addImportSymbol(const UnownedStringSlice& slice) + { + return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, slice); + } + SerialIndex addImportSymbol(const String& string) + { + return _addStringSlice(SerialTypeKind::ImportSymbol, m_importSymbolMap, string.getUnownedSlice()); + } /// Set a the ptr associated with an index. /// NOTE! That there cannot be a pre-existing setting. |
