diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-16 12:17:49 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-16 12:17:49 -0800 |
| commit | 801aa3b44254341018a1acbe754f2ce3b0900e2a (patch) | |
| tree | b3066778522edb99bf64c0ac80c91b0b4cb788f8 | |
| parent | 09d8e048d2264d89886cda8e87e8a452d4f913c1 (diff) | |
Clean up type checking of higher order expressions. (#2519)
* Clean up type checking of higher order expressions.
* Replace `goto` with `break` to pacify clang.
* Fix.
* Fixes.
* Fix more tests.
* Fix lowerWitnessTable parameter error.
* Exclude attributes from ast printing.
Co-authored-by: Yong He <yhe@nvidia.com>
27 files changed, 604 insertions, 321 deletions
diff --git a/source/compiler-core/slang-diagnostic-sink.cpp b/source/compiler-core/slang-diagnostic-sink.cpp index 5e58098b0..fa638316c 100644 --- a/source/compiler-core/slang-diagnostic-sink.cpp +++ b/source/compiler-core/slang-diagnostic-sink.cpp @@ -69,7 +69,7 @@ void printDiagnosticArg(StringBuilder& sb, Token const& token) sb << token.getContent(); } -SourceLoc const& getDiagnosticPos(Token const& token) +SourceLoc getDiagnosticPos(Token const& token) { return token.loc; } diff --git a/source/compiler-core/slang-diagnostic-sink.h b/source/compiler-core/slang-diagnostic-sink.h index 466a68827..5131e5194 100644 --- a/source/compiler-core/slang-diagnostic-sink.h +++ b/source/compiler-core/slang-diagnostic-sink.h @@ -109,9 +109,9 @@ void printDiagnosticArg(StringBuilder& sb, RefPtr<T> ptr) printDiagnosticArg(sb, ptr.Ptr()); } -inline SourceLoc const& getDiagnosticPos(SourceLoc const& pos) { return pos; } +inline SourceLoc getDiagnosticPos(SourceLoc const& pos) { return pos; } -SourceLoc const& getDiagnosticPos(Token const& token); +SourceLoc getDiagnosticPos(Token const& token); template<typename T> diff --git a/source/core/core.natvis b/source/core/core.natvis index 08446db8d..d9035e8ba 100644 --- a/source/core/core.natvis +++ b/source/core/core.natvis @@ -100,7 +100,17 @@ </LinkedListItems> </Expand> </Type> - + <Type Name="Slang::OrderedDictionary<*,*>"> + <DisplayString>{{ size={_count} }}</DisplayString> + <Expand> + <LinkedListItems> + <Size>_count</Size> + <HeadPointer>kvPairs.head</HeadPointer> + <NextPointer>next</NextPointer> + <ValueNode>Value</ValueNode> + </LinkedListItems> + </Expand> + </Type> <Type Name="Slang::RefPtr<*>"> <SmartPointer Usage="Minimal">pointer</SmartPointer> <DisplayString Condition="pointer == 0">empty</DisplayString> diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 86b72e05a..3a99ac15f 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -440,13 +440,19 @@ class OpenRefExpr : public Expr class HigherOrderInvokeExpr : public Expr { SLANG_ABSTRACT_AST_CLASS(HigherOrderInvokeExpr) - Expr* baseFunction; + Expr* baseFunction; + List<Name*> newParameterNames; }; +class DifferentiateExpr : public HigherOrderInvokeExpr +{ + SLANG_ABSTRACT_AST_CLASS(DifferentiateExpr) + +}; /// An expression of the form `__fwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// -class ForwardDifferentiateExpr: public HigherOrderInvokeExpr +class ForwardDifferentiateExpr: public DifferentiateExpr { SLANG_AST_CLASS(ForwardDifferentiateExpr) }; @@ -454,7 +460,7 @@ class ForwardDifferentiateExpr: public HigherOrderInvokeExpr /// An expression of the form `__bwd_diff(fn)` to access the /// forward-mode derivative version of the function `fn` /// -class BackwardDifferentiateExpr: public HigherOrderInvokeExpr +class BackwardDifferentiateExpr: public DifferentiateExpr { SLANG_AST_CLASS(BackwardDifferentiateExpr) }; diff --git a/source/slang/slang-ast-modifier.h b/source/slang/slang-ast-modifier.h index 24f614019..04af66b50 100644 --- a/source/slang/slang-ast-modifier.h +++ b/source/slang/slang-ast-modifier.h @@ -1068,7 +1068,9 @@ class ForwardDerivativeAttribute : public DifferentiableAttribute /// The `[ForwardDerivativeOf(primalFunction)]` attribute marks the decorated function as custom /// derivative implementation for `primalFunction`. -class ForwardDerivativeOfAttribute : public Attribute + /// ForwardDerivativeOfAttribute inherits from DifferentiableAttribute because a derivative + /// function itself is considered differentiable. +class ForwardDerivativeOfAttribute : public DifferentiableAttribute { SLANG_AST_CLASS(ForwardDerivativeOfAttribute) diff --git a/source/slang/slang-ast-print.cpp b/source/slang/slang-ast-print.cpp index 830c6bf34..107ec64c5 100644 --- a/source/slang/slang-ast-print.cpp +++ b/source/slang/slang-ast-print.cpp @@ -31,6 +31,7 @@ void ASTPrinter::addType(Type* type) m_builder << "<error>"; return; } + type = type->getCanonicalType(); if (m_optionFlags & OptionFlag::SimplifiedBuiltinType) { if (auto vectorType = as<VectorExpressionType>(type)) @@ -357,9 +358,16 @@ void ASTPrinter::addDeclKindPrefix(Decl* decl) continue; if (as<BuiltinModifier>(modifier)) continue; + if (as<BuiltinRequirementModifier>(modifier)) + continue; if (as<BuiltinTypeModifier>(modifier)) continue; + if (as<SpecializedForTargetModifier>(modifier)) + continue; } + // Don't print out attributes. + if (as<AttributeBase>(modifier)) + continue; m_builder << modifier->getKeywordName()->text << " "; } } @@ -462,7 +470,9 @@ void ASTPrinter::addDeclResultType(const DeclRef<Decl>& inDeclRef) /* static */String ASTPrinter::getDeclSignatureString(DeclRef<Decl> declRef, ASTBuilder* astBuilder) { - ASTPrinter astPrinter(astBuilder); + ASTPrinter astPrinter( + astBuilder, + ASTPrinter::OptionFlag::NoInternalKeywords | ASTPrinter::OptionFlag::SimplifiedBuiltinType); astPrinter.addDeclSignature(declRef); return astPrinter.getString(); } diff --git a/source/slang/slang-ast-substitutions.cpp b/source/slang/slang-ast-substitutions.cpp index 7c6cf1bc6..3bb3f69e1 100644 --- a/source/slang/slang-ast-substitutions.cpp +++ b/source/slang/slang-ast-substitutions.cpp @@ -157,7 +157,7 @@ bool ThisTypeSubstitution::_equalsOverride(Substitutions* subst) HashCode ThisTypeSubstitution::_getHashCodeOverride() const { - return witness->getHashCode(); + return witness->sub->getHashCode(); } } // namespace Slang diff --git a/source/slang/slang-ast-support-types.cpp b/source/slang/slang-ast-support-types.cpp index 1f30e0238..b550eaa68 100644 --- a/source/slang/slang-ast-support-types.cpp +++ b/source/slang/slang-ast-support-types.cpp @@ -1,6 +1,7 @@ #include "slang-ast-support-types.h" #include "slang-ast-base.h" #include "slang-ast-type.h" +#include "slang-ast-expr.h" namespace Slang { @@ -34,4 +35,31 @@ void removeModifier(ModifiableSyntaxNode* syntax, Modifier* toRemove) prev = modifier; } } + +Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr) +{ + HashSet<Expr*> workListSet; + while (auto higherOrder = as<HigherOrderInvokeExpr>(expr)) + { + if (workListSet.Add(higherOrder)) + { + expr = higherOrder->baseFunction; + } + else + { + // Circularity, return null. + return nullptr; + } + } + return expr; +} + +UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr) +{ + if (as<ForwardDifferentiateExpr>(expr)) + return UnownedStringSlice("fwd_diff"); + else if (as<BackwardDifferentiateExpr>(expr)) + return UnownedStringSlice("bwd_diff"); + return UnownedStringSlice(); +} } diff --git a/source/slang/slang-ast-support-types.h b/source/slang/slang-ast-support-types.h index 89ae0da7d..7c954987e 100644 --- a/source/slang/slang-ast-support-types.h +++ b/source/slang/slang-ast-support-types.h @@ -64,8 +64,8 @@ namespace Slang void printDiagnosticArg(StringBuilder& sb, Val* val); class SyntaxNode; - SourceLoc const& getDiagnosticPos(SyntaxNode const* syntax); - SourceLoc const& getDiagnosticPos(TypeExp const& typeExp); + SourceLoc getDiagnosticPos(SyntaxNode const* syntax); + SourceLoc getDiagnosticPos(TypeExp const& typeExp); typedef NodeBase* (*SyntaxParseCallback)(Parser* parser, void* userData); @@ -793,7 +793,7 @@ namespace Slang // try to find the concrete decl that satisfies the associatedtype requirement from the // concrete type supplied by ThisTypeSubstittution. Val* _tryLookupConcreteAssociatedTypeFromThisTypeSubst(ASTBuilder* builder, DeclRef<Decl> declRef); - + void _printNestedDecl(const Substitutions* substitutions, Decl* decl, StringBuilder& out); template<typename T> struct DeclRef : DeclRefBase @@ -1446,7 +1446,6 @@ namespace Slang { SLANG_OBJ_CLASS(WitnessTable) - List<KeyValuePair<Decl*, RequirementWitness>> requirementList; RequirementDictionary requirementDictionary; void add(Decl* decl, RequirementWitness const& witness); @@ -1515,6 +1514,13 @@ namespace Slang DMulFunc, ///< The `IDifferentiable.dmul` function requirement }; + /// Get the inner most expr from an higher order expr chain, e.g. `__fwd_diff(__fwd_diff(f))`'s + /// inner most expr is `f`. + Expr* getInnerMostExprFromHigherOrderExpr(Expr* expr); + + /// Get the operator name from the higher order invoke expr. + UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr); + } // namespace Slang #endif diff --git a/source/slang/slang-ast-type.cpp b/source/slang/slang-ast-type.cpp index ba033c3ad..76623d01c 100644 --- a/source/slang/slang-ast-type.cpp +++ b/source/slang/slang-ast-type.cpp @@ -495,10 +495,9 @@ Type* OptionalType::getValueType() void NamedExpressionType::_toTextOverride(StringBuilder& out) { - Name* name = declRef.getName(); - if (name) + if (declRef.getDecl()) { - out << name->text; + _printNestedDecl(declRef.substitutions, declRef.getDecl(), out); } } diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index aae741770..ffbc5a841 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -3234,7 +3234,7 @@ namespace Slang // First we need to make sure the associated `Differential` type requirement is satisfied. bool hasDifferentialAssocType = false; - for (auto existingEntry : witnessTable->requirementList) + for (auto existingEntry : witnessTable->requirementDictionary) { if (auto builtinReqAttr = existingEntry.Key->findModifier<BuiltinRequirementModifier>()) { @@ -4678,7 +4678,6 @@ namespace Slang void SemanticsDeclBodyVisitor::visitFunctionDeclBase(FunctionDeclBase* decl) { auto newContext = withParentFunc(decl); - _maybeRegisterDifferentialBottomTypeConformance(newContext); // Run checking on attributes that can't be fully checked in header checking stage. checkDerivativeOfAttribute(decl); @@ -6008,7 +6007,7 @@ namespace Slang // without any additional substitutions. if (extDecl->targetType->equals(type)) { - return extDeclRef; + return createDefaultSubstitutionsIfNeeded(m_astBuilder, this, extDeclRef).as<ExtensionDecl>(); } if (!TryUnifyTypes(constraints, extDecl->targetType.Ptr(), type)) diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f568dd8df..311a5944b 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -960,23 +960,13 @@ namespace Slang Expr* SemanticsVisitor::CheckTerm(Expr* term) { auto checkedTerm = _CheckTerm(term); - - // Differentiable type checking. - // TODO: This can be super slow. - if (this->m_parentFunc && - this->m_parentFunc->findModifier<ForwardDifferentiableAttribute>()) - { - maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); - } - // Differentiable type checking. // TODO: This can be super slow. if (this->m_parentFunc && - this->m_parentFunc->findModifier<BackwardDifferentiableAttribute>()) + this->m_parentFunc->findModifier<DifferentiableAttribute>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); } - return checkedTerm; } @@ -1060,6 +1050,10 @@ namespace Slang { return overloadedExpr->base; } + else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr)) + { + return overloadedExpr2->base; + } return nullptr; } @@ -2009,7 +2003,7 @@ namespace Slang return primalType; } - Type* SemanticsVisitor::processJVPFuncType(FuncType* originalType) + Type* SemanticsVisitor::getForwardDiffFuncType(FuncType* originalType) { // Resolve JVP type here. // Note that this type checking needs to be in sync with @@ -2035,7 +2029,7 @@ namespace Slang return jvpType; } - Type* SemanticsVisitor::processBackwardDiffFuncType(FuncType* originalType) + Type* SemanticsVisitor::getBackwardDiffFuncType(FuncType* originalType) { // Resolve backward diff type here. // Note that this type checking needs to be in sync with @@ -2074,30 +2068,151 @@ namespace Slang return type; } - Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) + struct DifferentiateExprCheckingActions { - // Check/Resolve inner function declaration. - expr->baseFunction = CheckTerm(expr->baseFunction); + virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) = 0; + virtual void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) = 0; + FuncType* getBaseFunctionType(SemanticsVisitor* semantics, Expr* funcExpr) + { + if (auto funcType = as<FuncType>(funcExpr->type.type)) + return funcType; + auto astBuilder = semantics->getASTBuilder(); + if (auto declRefExpr = as<DeclRefExpr>(funcExpr)) + { + if (auto baseFuncGenericDeclRef = declRefExpr->declRef.as<GenericDecl>()) + { + // Get inner function + DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( + getInner(baseFuncGenericDeclRef), + baseFuncGenericDeclRef.substitutions); + auto callableDeclRef = unspecializedInnerRef.as<CallableDecl>(); + if (!callableDeclRef) + return nullptr; + auto funcType = getFuncType(astBuilder, callableDeclRef); + return funcType; + } + } + return nullptr; + } + }; - // For now we only support using higher order expr as callee in an invoke expr. - // The actual type of the higher order function will be derived during resolve invoke. - expr->type = m_astBuilder->getBottomType(); + struct ForwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + { + virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create<ForwardDifferentiateExpr>(); + } + void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) + { + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + return; + } + resultDiffExpr->type = semantics->getForwardDiffFuncType(baseFuncType); + if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) + { + if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + { + for (auto param : funcDecl.getDecl()->getParameters()) + { + resultDiffExpr->newParameterNames.add(param->getName()); + } + } + } + } + }; - return expr; - } + struct BackwardDifferentiateExprCheckingActions : DifferentiateExprCheckingActions + { + virtual DifferentiateExpr* createDifferentiateExpr(SemanticsVisitor* semantics) override + { + return semantics->getASTBuilder()->create<BackwardDifferentiateExpr>(); + } + void fillDifferentiateExpr(DifferentiateExpr* resultDiffExpr, SemanticsVisitor* semantics, Expr* funcExpr) override + { + resultDiffExpr->baseFunction = funcExpr; + auto baseFuncType = getBaseFunctionType(semantics, funcExpr); + if (!baseFuncType) + { + resultDiffExpr->type = semantics->getASTBuilder()->getErrorType(); + semantics->getSink()->diagnose(funcExpr, Diagnostics::expectedFunction, funcExpr->type.type); + } + resultDiffExpr->type = semantics->getBackwardDiffFuncType(baseFuncType); + if (auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(funcExpr))) + { + if (auto funcDecl = declRefExpr->declRef.as<CallableDecl>()) + { + for (auto param : funcDecl.getDecl()->getParameters()) + { + resultDiffExpr->newParameterNames.add(param->getName()); + } + } + resultDiffExpr->newParameterNames.add(semantics->getName("resultGradient")); + } + } + }; - Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + static Expr* _checkDifferentiateExpr( + SemanticsVisitor* semantics, + DifferentiateExpr* expr, + DifferentiateExprCheckingActions* actions) { // Check/Resolve inner function declaration. - expr->baseFunction = CheckTerm(expr->baseFunction); + expr->baseFunction = semantics->CheckTerm(expr->baseFunction); - // For now we only support using higher order expr as callee in an invoke expr. - // The actual type of the higher order function will be derived during resolve invoke. - expr->type = m_astBuilder->getBottomType(); + auto astBuilder = semantics->getASTBuilder(); + // If base is overloaded expr, we want to return an overloaded expr as check result. + // This is done by pushing the `differentiate` operator to each item in the overloaded expr. + if (auto overloadedExpr = as<OverloadedExpr>(expr->baseFunction)) + { + OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>(); + for (auto item : overloadedExpr->lookupResult2) + { + auto lookupResultExpr = semantics->ConstructLookupResultExpr(item, + nullptr, + expr->loc, + nullptr); + auto candidateExpr = actions->createDifferentiateExpr(semantics); + actions->fillDifferentiateExpr(candidateExpr, semantics, lookupResultExpr); + result->candidiateExprs.add(candidateExpr); + } + result->type.type = astBuilder->getOverloadedType(); + return result; + } + else if (auto overloadedExpr2 = as<OverloadedExpr2>(expr->baseFunction)) + { + OverloadedExpr2* result = astBuilder->create<OverloadedExpr2>(); + for (auto item : overloadedExpr2->candidiateExprs) + { + auto candidateExpr = actions->createDifferentiateExpr(semantics); + actions->fillDifferentiateExpr(candidateExpr, semantics, item); + result->candidiateExprs.add(candidateExpr); + } + result->type.type = astBuilder->getOverloadedType(); + return result; + } + + actions->fillDifferentiateExpr(expr, semantics, expr->baseFunction); return expr; } + Expr* SemanticsExprVisitor::visitForwardDifferentiateExpr(ForwardDifferentiateExpr* expr) + { + ForwardDifferentiateExprCheckingActions actions; + return _checkDifferentiateExpr(this, expr, &actions); + } + + Expr* SemanticsExprVisitor::visitBackwardDifferentiateExpr(BackwardDifferentiateExpr* expr) + { + BackwardDifferentiateExprCheckingActions actions; + return _checkDifferentiateExpr(this, expr, &actions); + } + Expr* SemanticsExprVisitor::visitGetArrayLengthExpr(GetArrayLengthExpr* expr) { expr->arrayExpr = CheckTerm(expr->arrayExpr); @@ -2923,7 +3038,7 @@ namespace Slang // because vectors are also declaration reference types... // // Also note: the way this is done right now means that the ability - // to swizzle vectors interferes with any chance of looking up + // to swizzle vectors interferes with any chance o<f looking up // members via extension, for vector or scalar types. // // TODO: Matrix swizzles probably need to be handled at some point. diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index e7681212f..c4c32a681 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -190,6 +190,9 @@ namespace Slang // Reference to the declaration being applied LookupResultItem item; + // The expression when flavor is Expr. + Expr* exprVal = nullptr; + // Type of function being applied (for cases where `item` is not used) FuncType* funcType = nullptr; @@ -712,9 +715,9 @@ namespace Slang Type* getDifferentialPairType(Type* primalType); - // Convert a function's original type to it's JVP type. - Type* processJVPFuncType(FuncType* originalType); - Type* processBackwardDiffFuncType(FuncType* originalType); + // Convert a function's original type to it's forward/backward diff'd type. + Type* getForwardDiffFuncType(FuncType* originalType); + Type* getBackwardDiffFuncType(FuncType* originalType); /// Registers a type as conforming to IDifferentiable, along with a witness /// describing the relationship. diff --git a/source/slang/slang-check-modifier.cpp b/source/slang/slang-check-modifier.cpp index aa28571a7..9742e69bb 100644 --- a/source/slang/slang-check-modifier.cpp +++ b/source/slang/slang-check-modifier.cpp @@ -648,7 +648,6 @@ namespace Slang getSink()->diagnose(primalFunc, Slang::Diagnostics::invalidCustomDerivative, as<Decl>(attrTarget)); return false; } - forwardDerivativeOfAttr->funcExpr = primalFunc; } else if (auto comInterfaceAttr = as<ComInterfaceAttribute>(attr)) diff --git a/source/slang/slang-check-overload.cpp b/source/slang/slang-check-overload.cpp index fe9de9433..83774303b 100644 --- a/source/slang/slang-check-overload.cpp +++ b/source/slang/slang-check-overload.cpp @@ -714,22 +714,7 @@ namespace Slang callExpr->originalFunctionExpr = callExpr->functionExpr; callExpr->type = QualType(candidate.resultType); - - // If the callee is the result of a higher-order function invocation, - // set it's base function to the declaration corresponding to the - // resolved overload. - // - if (auto higherOrderInvoke = as<HigherOrderInvokeExpr>(callExpr->functionExpr)) - { - higherOrderInvoke->baseFunction = ConstructLookupResultExpr( - candidate.item, - baseExpr, - higherOrderInvoke->loc, - callExpr->functionExpr); - - higherOrderInvoke->type = candidate.funcType; - } - + callExpr->functionExpr = candidate.exprVal; return callExpr; } @@ -1252,10 +1237,19 @@ namespace Slang // to match it up with the arguments accordingly... if (auto funcDeclRef = partiallySpecializedInnerRef.as<CallableDecl>()) { - auto params = getParameters(funcDeclRef).toArray(); + List<Type*> paramTypes; + if (!innerParameterTypes) + { + auto params = getParameters(funcDeclRef).toArray(); + for (auto param : params) + { + paramTypes.add(getType(m_astBuilder, param)); + } + innerParameterTypes = ¶mTypes; + } Index valueArgCount = context.getArgCount(); - Index valueParamCount = params.getCount(); + Index valueParamCount = innerParameterTypes->getCount(); // If there are too many arguments, we cannot possibly have a match. // @@ -1295,7 +1289,7 @@ namespace Slang TryUnifyTypes( constraints, context.getArgTypeForInference(aa, this), - (!innerParameterTypes) ? getType(m_astBuilder, params[aa]) : (*innerParameterTypes)[aa]); + (*innerParameterTypes)[aa]); } } else @@ -1495,6 +1489,11 @@ namespace Slang // for anything applicable. AddDeclRefOverloadCandidates(LookupResultItem(declRefExpr->declRef), context); } + else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr)) + { + // The expression is the result of a higher order function application. + AddHigherOrderOverloadCandidates(higherOrderExpr, context); + } else if (auto funcType = as<FuncType>(funcExprType)) { // TODO(tfoley): deprecate this path... @@ -1511,11 +1510,6 @@ namespace Slang AddOverloadCandidates(item, context); } } - else if (auto higherOrderExpr = as<HigherOrderInvokeExpr>(funcExpr)) - { - // The expression is the result of a higher order function application. - AddHigherOrderOverloadCandidates(higherOrderExpr, context); - } else if (auto partiallyAppliedGenericExpr = as<PartiallyAppliedGenericExpr>(funcExpr)) { // A partially-applied generic is allowed as an overload candidate, @@ -1550,90 +1544,43 @@ namespace Slang // if-else ladder. if (auto expr = as<HigherOrderInvokeExpr>(funcExpr)) { - if (auto origFuncType = as<FuncType>(expr->baseFunction->type)) + auto funcDeclRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(expr->baseFunction)); + if (!funcDeclRefExpr) + return; + if (auto baseFuncDeclRef = funcDeclRefExpr->declRef.as<CallableDecl>()) { - - auto baseFuncDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<CallableDecl>(); - SLANG_ASSERT(baseFuncDeclRef); - + // Base is a normal or fully specialized generic function. OverloadCandidate candidate; candidate.flavor = OverloadCandidate::Flavor::Expr; - if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) - { - // Case: __fwd_diff(name-resolved-to-decl-ref) - candidate.funcType = as<FuncType>(processJVPFuncType(origFuncType)); - } - else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) + if (auto diffExpr = as<DifferentiateExpr>(expr)) { - // Case: __bwd_diff(name-resolved-to-decl-ref) - candidate.funcType = as<FuncType>(processBackwardDiffFuncType(origFuncType)); + candidate.funcType = as<FuncType>(diffExpr->type.type); } candidate.resultType = candidate.funcType->getResultType(); candidate.item = LookupResultItem(baseFuncDeclRef); - + candidate.exprVal = expr; AddOverloadCandidate(context, candidate); } - else if (auto origOverloadedType = as<OverloadGroupType>(expr->baseFunction->type)) - { - - if (auto overloadExpr = as<OverloadedExpr>(expr->baseFunction)) - { - for (auto item : overloadExpr->lookupResult2.items) - { - auto funcType = as<FuncType>(GetTypeForDeclRef(item.declRef, item.declRef.decl->loc)); - if (!funcType) - continue; - if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) - { - // Case: __fwd_diff(name-resolved-to-decl-ref) - funcType = as<FuncType>(processJVPFuncType(funcType)); - } - else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) - { - // Case: __bwd_diff(name-resolved-to-decl-ref) - funcType = as<FuncType>(processBackwardDiffFuncType(funcType)); - } - if (!funcType) - continue; - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - candidate.funcType = funcType; - candidate.resultType = candidate.funcType->getResultType(); - candidate.item = LookupResultItem(item.declRef); - - AddOverloadCandidate(context, candidate); - } - } - else - { - // Unhandled overload expr. - funcExpr->type = this->getASTBuilder()->getErrorType(); - getSink()->diagnose(funcExpr->loc, - Diagnostics::unimplemented, - funcExpr->type); - } - } - else if (auto baseFuncGenericDeclRef = as<DeclRefExpr>(expr->baseFunction)->declRef.as<GenericDecl>()) + else if (auto baseFuncGenericDeclRef = funcDeclRefExpr->declRef.as<GenericDecl>()) { - // Get inner function DeclRef<Decl> unspecializedInnerRef = DeclRef<Decl>( getInner(baseFuncGenericDeclRef), baseFuncGenericDeclRef.substitutions); - - // Pull parameter list of inner function. - auto funcType = getFuncType(this->getASTBuilder(), unspecializedInnerRef.as<CallableDecl>()); // Process func type to generate JVP func type. - auto jvpFuncType = as<ForwardDifferentiateExpr>(expr) ? - as<FuncType>(processJVPFuncType(funcType)) : - as<FuncType>(processBackwardDiffFuncType(funcType)); + auto diffFuncType = as<FuncType>(expr->type.type); + if (!diffFuncType) + { + // This shouldn't happen, but we check to be safe. + return; + } // Extract parameter list from processed type. List<Type*> paramTypes; - for(UIndex ii = 0; ii < jvpFuncType->getParamCount(); ii++) - paramTypes.add(jvpFuncType->getParamType(ii)); + for (UIndex ii = 0; ii < diffFuncType->getParamCount(); ii++) + paramTypes.add(diffFuncType->getParamType(ii)); // Try to infer generic arguments, based on the updated context. DeclRef<Decl> innerRef = inferGenericArguments( @@ -1641,39 +1588,39 @@ namespace Slang context, nullptr, ¶mTypes); - + + OverloadCandidate candidate; + candidate.flavor = OverloadCandidate::Flavor::Expr; if (innerRef) { - OverloadCandidate candidate; - candidate.flavor = OverloadCandidate::Flavor::Expr; - - // Note that we call processJVPFuncType() again here - // in order to process the specialized version of the original func type. - // This could potentially be a declRef.substitute(jvpFuncType) - // - if (auto fwdExpr = as<ForwardDifferentiateExpr>(expr)) - { - // Case: __fwd_diff(name-resolved-to-generic-decl) - candidate.funcType = as<FuncType>(processJVPFuncType( - getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); - } - else if (auto bwdExpr = as<BackwardDifferentiateExpr>(expr)) - { - // Case: __bwd_diff(name-resolved-to-generic-decl) - candidate.funcType = as<FuncType>(processBackwardDiffFuncType( - getFuncType(this->getASTBuilder(), innerRef.as<CallableDecl>()))); - } - - candidate.resultType = candidate.funcType->getResultType(); + diffFuncType = as<FuncType>(innerRef.substitute(m_astBuilder, diffFuncType)); candidate.item = LookupResultItem(innerRef); - - AddOverloadCandidate(context, candidate); } else { - SLANG_UNEXPECTED("Could not resolve generic candidate"); + candidate.item = LookupResultItem(funcDeclRefExpr->declRef); } + candidate.funcType = as<FuncType>(diffFuncType); + candidate.resultType = candidate.funcType->getResultType(); + // Substitute all types in the high-order expression chain. + Expr* inner = expr; + HigherOrderInvokeExpr* lastInner = nullptr; + while (auto hoInner = as<HigherOrderInvokeExpr>(inner)) + { + lastInner = hoInner; + hoInner->type = innerRef.substitute(m_astBuilder, hoInner->type.type); + inner = hoInner->baseFunction; + } + // Set inner expression to resolved declref expr. + if (lastInner) + { + auto baseExpr = GetBaseExpr(funcDeclRefExpr); + lastInner->baseFunction = ConstructLookupResultExpr(candidate.item, baseExpr, funcDeclRefExpr->loc, funcDeclRefExpr); + } + candidate.exprVal = expr; + expr->type.type = diffFuncType; + AddOverloadCandidate(context, candidate); } else { @@ -1683,6 +1630,7 @@ namespace Slang Diagnostics::expectedFunction, funcExpr->type); } + } } @@ -1769,18 +1717,7 @@ namespace Slang context.args = expr->arguments.getBuffer(); context.loc = expr->loc; - if (auto funcMemberExpr = as<MemberExpr>(funcExpr)) - { - context.baseExpr = funcMemberExpr->baseExpression; - } - else if (auto funcOverloadExpr = as<OverloadedExpr>(funcExpr)) - { - context.baseExpr = funcOverloadExpr->base; - } - else if (auto funcOverloadExpr2 = as<OverloadedExpr2>(funcExpr)) - { - context.baseExpr = funcOverloadExpr2->base; - } + context.baseExpr = GetBaseExpr(funcExpr); // TODO: We should have a special case here where an `InvokeExpr` // with a single argument where the base/func expression names diff --git a/source/slang/slang-ir-dce.cpp b/source/slang/slang-ir-dce.cpp index d58e307da..14edf21d7 100644 --- a/source/slang/slang-ir-dce.cpp +++ b/source/slang/slang-ir-dce.cpp @@ -305,6 +305,17 @@ bool shouldInstBeLiveIfParentIsLive(IRInst* inst, IRDeadCodeEliminationOptions o { return true; } + if (inst->findDecoration<IRImportDecoration>()) + { + if (inst->findDecoration<IRForwardDerivativeDecoration>()) + return true; + if (auto genInst = as<IRGeneric>(inst)) + { + auto inner = findInnerMostGenericReturnVal(genInst); + if (inner->findDecoration<IRForwardDerivativeDecoration>()) + return true; + } + } } if (options.keepLayoutsAlive && inst->findDecoration<IRLayoutDecoration>()) diff --git a/source/slang/slang-ir-diff-jvp.cpp b/source/slang/slang-ir-diff-jvp.cpp index 1597c80d1..152601dbd 100644 --- a/source/slang/slang-ir-diff-jvp.cpp +++ b/source/slang/slang-ir-diff-jvp.cpp @@ -456,8 +456,15 @@ struct DifferentialPairTypeBuilder IRInst* _createDiffPairType(IRType* origBaseType, IRType* diffType) { - SLANG_ASSERT(!as<IRParam>(origBaseType)); - SLANG_ASSERT(diffType); + switch (origBaseType->getOp()) + { + case kIROp_lookup_interface_method: + case kIROp_Specialize: + case kIROp_Param: + return nullptr; + default: + break; + } if (diffType->getOp() != kIROp_DifferentialBottomType) { IRBuilder builder(sharedContext->sharedBuilder); @@ -511,6 +518,8 @@ struct DifferentialPairTypeBuilder } auto diffType = getDiffTypeFromPairType(builder, pairType); + if (!diffType) + return result; result.loweredType = _createDiffPairType(pairType->getValueType(), (IRType*)diffType); result.isTrivial = (diffType->getOp() == kIROp_DifferentialBottomType); pairTypeCache.Add(originalPairType, result); @@ -1431,10 +1440,10 @@ struct JVPTranscriber else { getSink()->diagnose(origSpecialize->sourceLoc, - Diagnostics::unexpected, - "should not be attempting to differentiate anything specialized here."); + Diagnostics::unexpected, + "should not be attempting to differentiate anything specialized here."); } - + return InstPair(nullptr, nullptr); } @@ -2740,7 +2749,16 @@ struct JVPDerivativeContext : public InstPassBase { case kIROp_ForwardDifferentiate: case kIROp_BackwardDifferentiate: - autoDiffWorkList.add(inst); + // Only process now if the operand is a materialized function. + switch (inst->getOperand(0)->getOp()) + { + case kIROp_Func: + case kIROp_Specialize: + autoDiffWorkList.add(inst); + break; + default: + break; + } break; default: break; @@ -2752,59 +2770,63 @@ struct JVPDerivativeContext : public InstPassBase // Process collected `ForwardDifferentiate` insts and replace them with placeholders for // differentiated functions. + transcriberStorage.followUpFunctionsToTranscribe.clear(); backwardTranscriberStorage.followUpFunctionsToTranscribe.clear(); for (auto differentiateInst : autoDiffWorkList) { IRInst* baseInst = differentiateInst->getOperand(0); - - if (auto baseFunction = as<IRGlobalValueWithCode>(baseInst)) + if (as<IRForwardDifferentiate>(differentiateInst)) { - if (as<IRForwardDifferentiate>(differentiateInst)) + if (auto existingDiffFunc = lookupJVPReference(baseInst)) + { + differentiateInst->replaceUsesWith(existingDiffFunc); + differentiateInst->removeAndDeallocate(); + } + else if (isMarkedForForwardDifferentiation(baseInst)) { - if (auto existingDiffFunc = lookupJVPReference(baseFunction)) + if (as<IRFunc>(baseInst) || as<IRGeneric>(baseInst)) { - differentiateInst->replaceUsesWith(existingDiffFunc); + IRInst* diffFunc = transcriberStorage.transcribe(builder, baseInst); + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); differentiateInst->removeAndDeallocate(); } - else if (isMarkedForForwardDifferentiation(baseFunction)) + else { - if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) - { - IRInst* diffFunc = transcriberStorage.transcribe(builder, baseFunction); - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); } } - else if (as<IRBackwardDifferentiate>(differentiateInst)) + else { - if (isMarkedForBackwardDifferentiation(baseFunction)) + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Requested differentiation on a function that isn't marked as differentiable."); + } + + } + else if (as<IRBackwardDifferentiate>(differentiateInst)) + { + if (isMarkedForBackwardDifferentiation(baseInst)) + { + if (as<IRFunc>(baseInst)) { - if (as<IRFunc>(baseFunction) || as<IRGeneric>(baseFunction)) - { - IRInst* diffFunc = - backwardTranscriberStorage - .transcribeFuncHeader(builder, (IRFunc*)baseFunction) - .differential; - SLANG_ASSERT(diffFunc); - differentiateInst->replaceUsesWith(diffFunc); - differentiateInst->removeAndDeallocate(); - } - else - { - getSink()->diagnose(differentiateInst->sourceLoc, - Diagnostics::internalCompilerError, - "Unexpected instruction. Expected func or generic"); - } + IRInst* diffFunc = + backwardTranscriberStorage + .transcribeFuncHeader(builder, (IRFunc*)baseInst) + .differential; + SLANG_ASSERT(diffFunc); + differentiateInst->replaceUsesWith(diffFunc); + differentiateInst->removeAndDeallocate(); + } + else + { + getSink()->diagnose(differentiateInst->sourceLoc, + Diagnostics::internalCompilerError, + "Unexpected instruction. Expected func or generic"); } } } @@ -3118,18 +3140,9 @@ struct JVPDerivativeContext : public InstPassBase // Checks decorators to see if the function should // be differentiated (kIROp_ForwardDifferentiableDecoration) // - bool isMarkedForForwardDifferentiation(IRGlobalValueWithCode* callable) + bool isMarkedForForwardDifferentiation(IRInst* callable) { - for (auto decoration = callable->getFirstDecoration(); - decoration; - decoration = decoration->getNextDecoration()) - { - if (decoration->getOp() == kIROp_ForwardDifferentiableDecoration) - { - return true; - } - } - return false; + return callable->findDecoration<IRForwardDifferentiableDecoration>() != nullptr; } IRStringLit* getForwardDerivativeFuncName(IRInst* func) @@ -3153,18 +3166,9 @@ struct JVPDerivativeContext : public InstPassBase // Checks decorators to see if the function should // be differentiated (kIROp_ForwardDifferentiableDecoration) // - bool isMarkedForBackwardDifferentiation(IRGlobalValueWithCode* callable) + bool isMarkedForBackwardDifferentiation(IRInst* callable) { - for (auto decoration = callable->getFirstDecoration(); - decoration; - decoration = decoration->getNextDecoration()) - { - if (decoration->getOp() == kIROp_BackwardDifferentiableDecoration) - { - return true; - } - } - return false; + return callable->findDecoration<IRBackwardDifferentiableDecoration>() != nullptr; } IRStringLit* getBackwardDerivativeFuncName(IRInst* func) diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index cf0293f0d..8559103ae 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -1452,9 +1452,9 @@ LinkedIR linkIR( List<IRModule*> irModules; // Link stdlib modules. - auto builtinLinkage = static_cast<Session*>(linkage->getGlobalSession())->getBuiltinLinkage(); - for (auto& m : builtinLinkage->mapNameToLoadedModules) - irModules.add(m.Value->getIRModule()); + auto& stdlibModules = static_cast<Session*>(linkage->getGlobalSession())->stdlibModules; + for (auto& m : stdlibModules) + irModules.add(m->getIRModule()); // Link modules in the program. program->enumerateIRModules([&](IRModule* irModule) diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 1a8f20f1a..261f64130 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -19,6 +19,8 @@ namespace Slang void printDiagnosticArg(StringBuilder& sb, IRInst* irObject) { + if (!irObject) + return; if (auto nameHint = irObject->findDecoration<IRNameHintDecoration>()) sb << nameHint->getName(); } diff --git a/source/slang/slang-language-server.cpp b/source/slang/slang-language-server.cpp index bb662db91..15217dcac 100644 --- a/source/slang/slang-language-server.cpp +++ b/source/slang/slang-language-server.cpp @@ -987,6 +987,84 @@ SlangResult LanguageServer::signatureHelp( response.signatures.add(sigInfo); }; + auto addExpr = [&](Expr* expr) + { + auto higherOrderExpr = as<HigherOrderInvokeExpr>(expr); + if (!higherOrderExpr) + return; + auto funcType = as<FuncType>(higherOrderExpr->type); + if (!funcType) + return; + auto declRefExpr = as<DeclRefExpr>(getInnerMostExprFromHigherOrderExpr(higherOrderExpr)); + if (!declRefExpr) + return; + if (!declRefExpr->declRef.getDecl()) + return; + + SignatureInformation sigInfo; + + List<Slang::Range<Index>> paramRanges; + ASTPrinter printer( + version->linkage->getASTBuilder(), + ASTPrinter::OptionFlag::ParamNames | ASTPrinter::OptionFlag::NoInternalKeywords | + ASTPrinter::OptionFlag::SimplifiedBuiltinType); + + printer.addDeclKindPrefix(declRefExpr->declRef.getDecl()); + auto inner = higherOrderExpr; + int closingParentCount = 0; + while (inner) + { + printer.getStringBuilder() << getHigherOrderOperatorName(inner) << "("; + closingParentCount++; + inner = as<HigherOrderInvokeExpr>(inner->baseFunction); + } + printer.addDeclPath(declRefExpr->declRef); + for (int i = 0; i < closingParentCount; i++) + printer.getStringBuilder() << ")"; + bool isFirst = true; + printer.getStringBuilder() << "("; + int paramIndex = 0; + for (auto param : funcType->paramTypes) + { + if (!isFirst) + printer.getStringBuilder() << ", "; + Slang::Range<Index> range; + range.begin = printer.getStringBuilder().getLength(); + if (paramIndex < higherOrderExpr->newParameterNames.getCount()) + { + if (higherOrderExpr->newParameterNames[paramIndex]) + { + printer.getStringBuilder() << higherOrderExpr->newParameterNames[paramIndex]->text << ": "; + } + } + printer.addType(param); + range.end = printer.getStringBuilder().getLength(); + paramRanges.add(range); + isFirst = false; + paramIndex++; + } + printer.getStringBuilder() << ") -> "; + printer.addType(funcType->getResultType()); + + sigInfo.label = printer.getString(); + + StringBuilder docSB; + auto humaneLoc = version->linkage->getSourceManager()->getHumaneLoc(declRefExpr->declRef.getLoc(), SourceLocType::Actual); + _tryGetDocumentation(docSB, version, declRefExpr->declRef.getDecl()); + appendDefinitionLocation(docSB, m_workspace, humaneLoc); + sigInfo.documentation.value = docSB.ProduceString(); + sigInfo.documentation.kind = "markdown"; + + for (auto& range : paramRanges) + { + ParameterInformation paramInfo; + paramInfo.label[0] = (uint32_t)range.begin; + paramInfo.label[1] = (uint32_t)range.end; + sigInfo.parameters.add(paramInfo); + } + response.signatures.add(sigInfo); + }; + auto addFuncType = [&](FuncType* funcType) { SignatureInformation sigInfo; @@ -1045,6 +1123,17 @@ SlangResult LanguageServer::signatureHelp( addDeclRef(item.declRef); } } + else if (auto overloadedExpr2 = as<OverloadedExpr2>(funcExpr)) + { + for (auto item : overloadedExpr2->candidiateExprs) + { + addExpr(item); + } + } + else if (auto higherOrder = as<HigherOrderInvokeExpr>(funcExpr)) + { + addExpr(higherOrder); + } else if (auto funcType = as<FuncType>(funcExpr->type.type)) { addFuncType(funcType); diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index dc0af4f96..4c3f4d646 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -19,7 +19,6 @@ #include "slang-ir-string-hash.h" #include "slang-ir-clone.h" #include "slang-ir-lower-error-handling.h" - #include "slang-mangle.h" #include "slang-type-layout.h" #include "slang-visitor.h" @@ -6040,11 +6039,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IRGenContext* subContext, WitnessTable* astWitnessTable, IRWitnessTable* irWitnessTable, - Dictionary<WitnessTable*, IRWitnessTable*> mapASTToIRWitnessTable) + Dictionary<WitnessTable*, IRWitnessTable*> &mapASTToIRWitnessTable) { auto subBuilder = subContext->irBuilder; - for(auto entry : astWitnessTable->requirementList) + for(auto entry : astWitnessTable->requirementDictionary) { auto requiredMemberDecl = entry.Key; auto satisfyingWitness = entry.Value; @@ -8273,6 +8272,16 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // Reset cursor. subContext->irBuilder->setInsertInto(irFunc); } + + // For convenience, ensure that any additional global + // values that were emitted while outputting the function + // body appear before the function itself in the list + // of global values. + irFunc->moveToEnd(); + + // If this function is defined inside an interface, add a reference to the IRFunc from + // the interface's type definition. + auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); if (auto attr = decl->findModifier<ForwardDerivativeOfAttribute>()) { @@ -8281,7 +8290,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> NestedContext originalContextFunc(this); auto originalSubBuilder = originalContextFunc.getBuilder(); auto originalSubContext = originalContextFunc.getContext(); - + if (auto outterGeneric = getOuterGeneric(irFunc)) + originalSubBuilder->setInsertBefore(outterGeneric); + else + originalSubBuilder->setInsertBefore(irFunc); auto originalFuncDecl = as<FunctionDeclBase>(originalDeclRefExpr->declRef.getDecl()); SLANG_RELEASE_ASSERT(originalFuncDecl); @@ -8294,27 +8306,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto derivativeFuncVal = lowerRValueExpr(originalSubContext, attr->backDeclRef); originalSubBuilder->addForwardDerivativeDecoration(originalFuncVal, derivativeFuncVal.val); } - - subContext->irBuilder->setInsertInto(irFunc->getParent()); - auto loweredVal = lowerRValueExpr(subContext, attr->funcExpr); - - SLANG_ASSERT(loweredVal.flavor == LoweredValInfo::Flavor::Simple); - IRInst* originalFunc = loweredVal.val; - getBuilder()->addDecoration(irFunc, kIROp_ForwardDerivativeDecoration, originalFunc); - + getBuilder()->addForwardDifferentiableDecoration(irFunc); subContext->irBuilder->setInsertInto(irFunc); + finalVal->moveToEnd(); } - - // For convenience, ensure that any additional global - // values that were emitted while outputting the function - // body appear before the function itself in the list - // of global values. - irFunc->moveToEnd(); - - // If this function is defined inside an interface, add a reference to the IRFunc from - // the interface's type definition. - auto finalVal = finishOuterGenerics(subBuilder, irFunc, outerGeneric); - return LoweredValInfo::simple(finalVal); } diff --git a/source/slang/slang-mangle.cpp b/source/slang/slang-mangle.cpp index f88549e41..bdb5465e9 100644 --- a/source/slang/slang-mangle.cpp +++ b/source/slang/slang-mangle.cpp @@ -315,7 +315,7 @@ namespace Slang // Inheritance declarations don't have meaningful names, // and so we should emit them based on the type // that is doing the inheriting. - if(auto inheritanceDeclRef = declRef.as<InheritanceDecl>()) + if(auto inheritanceDeclRef = declRef.as<TypeConstraintDecl>()) { emit(context, "I"); emitType(context, getSup(context->astBuilder, inheritanceDeclRef)); diff --git a/source/slang/slang-syntax.cpp b/source/slang/slang-syntax.cpp index 80f57905f..4f05bc936 100644 --- a/source/slang/slang-syntax.cpp +++ b/source/slang/slang-syntax.cpp @@ -15,16 +15,22 @@ namespace Slang void printDiagnosticArg(StringBuilder& sb, Decl* decl) { + if (!decl) + return; sb << getText(decl->getName()); } void printDiagnosticArg(StringBuilder& sb, Type* type) { + if (!type) + return; type->toText(sb); } void printDiagnosticArg(StringBuilder& sb, Val* val) { + if (!val) + return; val->toText(sb); } @@ -44,13 +50,17 @@ void printDiagnosticArg(StringBuilder& sb, QualType const& type) sb << "<null>"; } -SourceLoc const& getDiagnosticPos(SyntaxNode const* syntax) +SourceLoc getDiagnosticPos(SyntaxNode const* syntax) { + if (!syntax) + return SourceLoc(); return syntax->loc; } -SourceLoc const& getDiagnosticPos(TypeExp const& typeExp) +SourceLoc getDiagnosticPos(TypeExp const& typeExp) { + if (!typeExp.exp) + return SourceLoc(); return typeExp.exp->loc; } @@ -365,7 +375,6 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt SLANG_ASSERT(!requirementDictionary.ContainsKey(decl)); requirementDictionary.Add(decl, witness); - requirementList.add(KeyValuePair<Decl*, RequirementWitness>(decl, witness)); } // @@ -1169,10 +1178,10 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } ThisTypeSubstitution* findThisTypeSubstitution( - Substitutions* substs, + const Substitutions* substs, InterfaceDecl* interfaceDecl) { - for(Substitutions* s = substs; s; s = s->outer) + for(const Substitutions* s = substs; s; s = s->outer) { auto thisTypeSubst = as<ThisTypeSubstitution>(s); if(!thisTypeSubst) @@ -1181,7 +1190,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt if(thisTypeSubst->interfaceDecl != interfaceDecl) continue; - return thisTypeSubst; + return const_cast<ThisTypeSubstitution*>(thisTypeSubst); } return nullptr; @@ -1236,7 +1245,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt } // Prints a partially qualified type name with generic substitutions. - static void _printNestedDecl(const Substitutions* substitutions, Decl* decl, StringBuilder& out) + void _printNestedDecl(const Substitutions* substitutions, Decl* decl, StringBuilder& out) { // If there is a parent scope for the declaration, print it first. // Exclude top-level namespaces like `tu0` or `core`. @@ -1258,12 +1267,28 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt out << "."; } } - - // Print this type's name. - auto name = decl->getName(); - if (name) + // If we have a ThisTypeSubstitution to an interface decl, print the substituted sub + // type instead. + for (;;) { - out << name->text; + if (auto interfaceDecl = as<InterfaceDecl>(decl)) + { + if (auto thisSubst = findThisTypeSubstitution(substitutions, interfaceDecl)) + { + if (auto subTypeWitness = as<SubtypeWitness>(thisSubst->witness)) + { + out << subTypeWitness->sub; + break; + } + } + } + // Otherwise, just print this type's name. + auto name = decl->getName(); + if (name) + { + out << name->text; + } + break; } // Look for generic substitutions on this type. @@ -1280,6 +1305,9 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt bool isFirst = true; for (const auto& it : genericSubstitution->getArgs()) { + // Don't print out witnesses. + if (as<Witness>(it)) + continue; if (!isFirst) out << ", "; isFirst = false; diff --git a/source/slang/slang-syntax.h b/source/slang/slang-syntax.h index 4e1900636..8f88ddb2d 100644 --- a/source/slang/slang-syntax.h +++ b/source/slang/slang-syntax.h @@ -276,7 +276,7 @@ namespace Slang // ThisTypeSubstitution* findThisTypeSubstitution( - Substitutions* substs, + const Substitutions* substs, InterfaceDecl* interfaceDecl); RequirementWitness tryLookUpRequirementWitness( diff --git a/tests/autodiff/high-order-diff-operator.slang b/tests/autodiff/high-order-diff-operator.slang new file mode 100644 index 000000000..dca67e9f3 --- /dev/null +++ b/tests/autodiff/high-order-diff-operator.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type +//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type + +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer +RWStructuredBuffer<float> outputBuffer; + +[ForwardDifferentiable] +float mySqr(float x) +{ + return x * x; +} + +[ForwardDifferentiable] +float f(float x) +{ + return mySqr(x * x); +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + // Given f(x) = x^4, + // f''(x) = 12 * x^2 + // Expect f''(4) = 192 + float.Differential t = 2; + outputBuffer[0] = __fwd_diff(__fwd_diff(f))( + DifferentialPair<DifferentialPair<float>>( + DifferentialPair<float>(4.0, 1.0), DifferentialPair<float>(1.0, 0.0))).d.d; + + // sin''(x) = cos'(x) = -sin(x). + // Expect sin''(Pi/2) = -1. + outputBuffer[1] = __fwd_diff(__fwd_diff(sin))( + DifferentialPair<DifferentialPair<float>>( + DifferentialPair<float>(float.getPi()/2, 1.0), DifferentialPair<float>(1.0, 0.0))).d.d; +} diff --git a/tests/autodiff/high-order-diff-operator.slang.expected.txt b/tests/autodiff/high-order-diff-operator.slang.expected.txt new file mode 100644 index 000000000..305a8e111 --- /dev/null +++ b/tests/autodiff/high-order-diff-operator.slang.expected.txt @@ -0,0 +1,5 @@ +type: float +192.000000 +-1.000000 +0.000000 +0.000000 diff --git a/tests/diagnostics/bad-operator-call.slang.expected b/tests/diagnostics/bad-operator-call.slang.expected index 14d1b858f..21cb6bd41 100644 --- a/tests/diagnostics/bad-operator-call.slang.expected +++ b/tests/diagnostics/bad-operator-call.slang.expected @@ -3,38 +3,38 @@ standard error = { tests/diagnostics/bad-operator-call.slang(18): error 39999: no overload for '+=' applicable to arguments of type (int, S) a += b; ^~ -core.meta.slang(2219): note 39999: candidate: __unsafeForceInlineEarly func +=<T, R:int, C:int>(out matrix<T,R,C>, T) -> matrix<T,R,C> -core.meta.slang(2211): note 39999: candidate: __unsafeForceInlineEarly func +=<T, R:int, C:int>(out matrix<T,R,C>, matrix<T,R,C>) -> matrix<T,R,C> -core.meta.slang(2203): note 39999: candidate: __unsafeForceInlineEarly func +=<T, N:int>(out vector<T,N>, T) -> vector<T,N> -core.meta.slang(2195): note 39999: candidate: __unsafeForceInlineEarly func +=<T, N:int>(out vector<T,N>, vector<T,N>) -> vector<T,N> -core.meta.slang(2187): note 39999: candidate: __unsafeForceInlineEarly func +=<T>(out T, T) -> T +core.meta.slang(2430): note 39999: candidate: func +=<T, R:int, C:int>(out matrix<T,R,C>, T) -> matrix<T,R,C> +core.meta.slang(2422): note 39999: candidate: func +=<T, R:int, C:int>(out matrix<T,R,C>, matrix<T,R,C>) -> matrix<T,R,C> +core.meta.slang(2414): note 39999: candidate: func +=<T, N:int>(out vector<T,N>, T) -> vector<T,N> +core.meta.slang(2406): note 39999: candidate: func +=<T, N:int>(out vector<T,N>, vector<T,N>) -> vector<T,N> +core.meta.slang(2398): note 39999: candidate: func +=<T>(out T, T) -> T tests/diagnostics/bad-operator-call.slang(20): error 39999: no overload for '+' applicable to arguments of type (int, S) a = a + b; ^ -core.meta.slang(2045): note 39999: candidate: __intrinsic_op func +(uintptr_t, uintptr_t) -> uintptr_t -core.meta.slang(2037): note 39999: candidate: __intrinsic_op func +(uint64_t, uint64_t) -> uint64_t -core.meta.slang(2029): note 39999: candidate: __intrinsic_op func +(uint, uint) -> uint -core.meta.slang(2021): note 39999: candidate: __intrinsic_op func +(uint16_t, uint16_t) -> uint16_t -core.meta.slang(2013): note 39999: candidate: __intrinsic_op func +(uint8_t, uint8_t) -> uint8_t -core.meta.slang(2005): note 39999: candidate: __intrinsic_op func +(double, double) -> double -core.meta.slang(1997): note 39999: candidate: __intrinsic_op func +(float, float) -> float -core.meta.slang(1989): note 39999: candidate: __intrinsic_op func +(half, half) -> half -core.meta.slang(1981): note 39999: candidate: __intrinsic_op func +(intptr_t, intptr_t) -> intptr_t -core.meta.slang(1973): note 39999: candidate: __intrinsic_op func +(int64_t, int64_t) -> int64_t +core.meta.slang(2256): note 39999: candidate: func +(uintptr_t, uintptr_t) -> uintptr_t +core.meta.slang(2248): note 39999: candidate: func +(uint64_t, uint64_t) -> uint64_t +core.meta.slang(2240): note 39999: candidate: func +(uint, uint) -> uint +core.meta.slang(2232): note 39999: candidate: func +(uint16_t, uint16_t) -> uint16_t +core.meta.slang(2224): note 39999: candidate: func +(uint8_t, uint8_t) -> uint8_t +core.meta.slang(2216): note 39999: candidate: func +(double, double) -> double +core.meta.slang(2208): note 39999: candidate: func +(float, float) -> float +core.meta.slang(2200): note 39999: candidate: func +(half, half) -> half +core.meta.slang(2192): note 39999: candidate: func +(intptr_t, intptr_t) -> intptr_t +core.meta.slang(2184): note 39999: candidate: func +(int64_t, int64_t) -> int64_t tests/diagnostics/bad-operator-call.slang(20): note 39999: 3 more overload candidates tests/diagnostics/bad-operator-call.slang(22): error 39999: no overload for '~' applicable to arguments of type (S) a = ~b; ^ -core.meta.slang(1914): note 39999: candidate: __prefix __intrinsic_op func ~(uintptr_t) -> uintptr_t -core.meta.slang(1910): note 39999: candidate: __prefix __intrinsic_op func ~(uint64_t) -> uint64_t -core.meta.slang(1906): note 39999: candidate: __prefix __intrinsic_op func ~(uint) -> uint -core.meta.slang(1902): note 39999: candidate: __prefix __intrinsic_op func ~(uint16_t) -> uint16_t -core.meta.slang(1898): note 39999: candidate: __prefix __intrinsic_op func ~(uint8_t) -> uint8_t -core.meta.slang(1894): note 39999: candidate: __prefix __intrinsic_op func ~(intptr_t) -> intptr_t -core.meta.slang(1890): note 39999: candidate: __prefix __intrinsic_op func ~(int64_t) -> int64_t -core.meta.slang(1886): note 39999: candidate: __prefix __intrinsic_op func ~(int) -> int -core.meta.slang(1882): note 39999: candidate: __prefix __intrinsic_op func ~(int16_t) -> int16_t -core.meta.slang(1878): note 39999: candidate: __prefix __intrinsic_op func ~(int8_t) -> int8_t +core.meta.slang(2125): note 39999: candidate: __prefix func ~(uintptr_t) -> uintptr_t +core.meta.slang(2121): note 39999: candidate: __prefix func ~(uint64_t) -> uint64_t +core.meta.slang(2117): note 39999: candidate: __prefix func ~(uint) -> uint +core.meta.slang(2113): note 39999: candidate: __prefix func ~(uint16_t) -> uint16_t +core.meta.slang(2109): note 39999: candidate: __prefix func ~(uint8_t) -> uint8_t +core.meta.slang(2105): note 39999: candidate: __prefix func ~(intptr_t) -> intptr_t +core.meta.slang(2101): note 39999: candidate: __prefix func ~(int64_t) -> int64_t +core.meta.slang(2097): note 39999: candidate: __prefix func ~(int) -> int +core.meta.slang(2093): note 39999: candidate: __prefix func ~(int16_t) -> int16_t +core.meta.slang(2089): note 39999: candidate: __prefix func ~(int8_t) -> int8_t tests/diagnostics/bad-operator-call.slang(27): error 30047: argument passed to parameter '0' must be l-value. a += c; ^ @@ -42,24 +42,24 @@ tests/diagnostics/bad-operator-call.slang(27): note 30048: argument was implicit tests/diagnostics/bad-operator-call.slang(31): error 39999: no overload for '+=' applicable to arguments of type (vector<float,3>, vector<int,4>) d += c; ^~ -core.meta.slang(2219): note 39999: candidate: __unsafeForceInlineEarly func +=<T, R:int, C:int>(out matrix<T,R,C>, T) -> matrix<T,R,C> -core.meta.slang(2211): note 39999: candidate: __unsafeForceInlineEarly func +=<T, R:int, C:int>(out matrix<T,R,C>, matrix<T,R,C>) -> matrix<T,R,C> -core.meta.slang(2203): note 39999: candidate: __unsafeForceInlineEarly func +=<T, N:int>(out vector<T,N>, T) -> vector<T,N> -core.meta.slang(2195): note 39999: candidate: __unsafeForceInlineEarly func +=<T, N:int>(out vector<T,N>, vector<T,N>) -> vector<T,N> -core.meta.slang(2187): note 39999: candidate: __unsafeForceInlineEarly func +=<T>(out T, T) -> T +core.meta.slang(2430): note 39999: candidate: func +=<T, R:int, C:int>(out matrix<T,R,C>, T) -> matrix<T,R,C> +core.meta.slang(2422): note 39999: candidate: func +=<T, R:int, C:int>(out matrix<T,R,C>, matrix<T,R,C>) -> matrix<T,R,C> +core.meta.slang(2414): note 39999: candidate: func +=<T, N:int>(out vector<T,N>, T) -> vector<T,N> +core.meta.slang(2406): note 39999: candidate: func +=<T, N:int>(out vector<T,N>, vector<T,N>) -> vector<T,N> +core.meta.slang(2398): note 39999: candidate: func +=<T>(out T, T) -> T tests/diagnostics/bad-operator-call.slang(33): error 39999: no overload for '+' applicable to arguments of type (vector<int,4>, vector<float,3>) d = c + d; ^ -core.meta.slang(2051): note 39999: candidate: __intrinsic_op func +<4>(vector<uintptr_t,4>, uintptr_t) -> vector<uintptr_t,4> -core.meta.slang(2049): note 39999: candidate: __intrinsic_op func +<3>(uintptr_t, vector<uintptr_t,3>) -> vector<uintptr_t,3> -core.meta.slang(2045): note 39999: candidate: __intrinsic_op func +(uintptr_t, uintptr_t) -> uintptr_t -core.meta.slang(2043): note 39999: candidate: __intrinsic_op func +<4>(vector<uint64_t,4>, uint64_t) -> vector<uint64_t,4> -core.meta.slang(2041): note 39999: candidate: __intrinsic_op func +<3>(uint64_t, vector<uint64_t,3>) -> vector<uint64_t,3> -core.meta.slang(2037): note 39999: candidate: __intrinsic_op func +(uint64_t, uint64_t) -> uint64_t -core.meta.slang(2035): note 39999: candidate: __intrinsic_op func +<4>(vector<uint,4>, uint) -> vector<uint,4> -core.meta.slang(2033): note 39999: candidate: __intrinsic_op func +<3>(uint, vector<uint,3>) -> vector<uint,3> -core.meta.slang(2029): note 39999: candidate: __intrinsic_op func +(uint, uint) -> uint -core.meta.slang(2027): note 39999: candidate: __intrinsic_op func +<4>(vector<uint16_t,4>, uint16_t) -> vector<uint16_t,4> +core.meta.slang(2262): note 39999: candidate: func +<4>(uintptr_t4, uintptr_t) -> uintptr_t4 +core.meta.slang(2260): note 39999: candidate: func +<3>(uintptr_t, uintptr_t3) -> uintptr_t3 +core.meta.slang(2256): note 39999: candidate: func +(uintptr_t, uintptr_t) -> uintptr_t +core.meta.slang(2254): note 39999: candidate: func +<4>(uint64_t4, uint64_t) -> uint64_t4 +core.meta.slang(2252): note 39999: candidate: func +<3>(uint64_t, uint64_t3) -> uint64_t3 +core.meta.slang(2248): note 39999: candidate: func +(uint64_t, uint64_t) -> uint64_t +core.meta.slang(2246): note 39999: candidate: func +<4>(uint4, uint) -> uint4 +core.meta.slang(2244): note 39999: candidate: func +<3>(uint, uint3) -> uint3 +core.meta.slang(2240): note 39999: candidate: func +(uint, uint) -> uint +core.meta.slang(2238): note 39999: candidate: func +<4>(uint16_t4, uint16_t) -> uint16_t4 tests/diagnostics/bad-operator-call.slang(33): note 39999: 29 more overload candidates } standard output = { |
