diff options
| author | Yong He <yonghe@outlook.com> | 2022-10-24 22:19:38 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-24 22:19:38 -0700 |
| commit | 41cb7c13e37ec32ffb6557d21da079d77151e136 (patch) | |
| tree | 38d2c44938e2679c42c5c0e73f5411e59015df93 /source/slang/slang-check-expr.cpp | |
| parent | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (diff) | |
Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460)
* wip: remove auto-diff for member access, add diff through property accessors.
* Fix getter-setter test.
* Fix getter-setter-multi test.
* Fix nested-jvp test.
* Use [DerivativeMember] attribute to differentiate through member access.
* Clean up.
* More cleanup.
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 157 |
1 files changed, 24 insertions, 133 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 745532c27..29b44e726 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -755,7 +755,7 @@ namespace Slang } else if (diffTypeLookupResult.isOverloaded()) { - SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported"); + getSink()->diagnose(declRefType->declRef, Diagnostics::ambiguousReference, getName("Differential")); } else { @@ -774,7 +774,7 @@ namespace Slang } } - return nullptr; + return m_astBuilder->getErrorType(); } void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) @@ -813,103 +813,6 @@ namespace Slang } } - Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm) - { - // Check that member lookups on differentiable types have appropriate differential - // getters and setters. - if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) - { - - // Check if we have a parent container. If yes, then checkedTerm is - // referencing a member of this parent. - // - auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent()); - - // Check if we have an aggregate (i.e. struct-like) type. - // Ignore interfaces and the case when the term refers to a function - // - if (parentType->declRef.as<AggTypeDeclBase>() && - !parentType->declRef.as<InterfaceDecl>() && - !declRefExpr->declRef.as<CallableDecl>()) - { - // Check if the parent container type is differentiable. - if (auto parentDiffWitness = as<SubtypeWitness>( - tryGetInterfaceConformanceWitness( - parentType, getASTBuilder()->getDifferentiableInterface()))) - { - // If yes, the member in checkedTerm should have a differential getter and setter. - // Otherwise, <ERROR> - // - auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>(); - diffExpr->type = checkedTerm->type; - diffExpr->inner = checkedTerm; - - { - auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text); - auto getterResult = lookUpMember( - getASTBuilder(), - this, - getterName, - parentType, - Slang::LookupMask::Function, - Slang::LookupOptions::None); - - if (!getterResult.isValid()) - { - // Do nothing.. we assume that this field cannot be differentiated. - // Could this be confusing from a user perspective? - } - else if (getterResult.isOverloaded()) - { - // Diagnose ambiguous getter. - SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported"); - } - else - { - auto getterRefExpr = ConstructLookupResultExpr( - getterResult.item, - declRefExpr, - getterResult.item.declRef.getLoc(), - nullptr); - - // Check that the type is what we expect. - // We're going to do this in a very crude way for now. - // Ideally, we want to use the overload resolution and type - // coercion logic in ResolveInvoke() - // - - auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type); - auto diffParentType = _getDifferential(m_astBuilder, parentType); - - auto ptrDiffType = m_astBuilder->getPtrType(diffType); - auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType); - - auto funcType = as<FuncType>(getterRefExpr->type); - - if (!ptrDiffType->equals(funcType->getResultType())) - { - getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, - ptrDiffType, funcType->getResultType()); - } - - if (!inoutContainerDiffType->equals(funcType->getParamType(0))) - { - getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, - inoutContainerDiffType, funcType->getParamType(0)); - } - - diffExpr->getterExpr = getterRefExpr; - } - } - - return diffExpr; - } - } - } - - return checkedTerm; - } - Expr* SemanticsVisitor::CheckTerm(Expr* term) { auto checkedTerm = _CheckTerm(term); @@ -920,11 +823,6 @@ namespace Slang this->m_parentFunc->findModifier<JVPDerivativeModifier>()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); - - if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) - { - checkedTerm = maybeMakeDifferentialExpr(checkedTerm); - } } return checkedTerm; @@ -1888,14 +1786,6 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) - { - auto checkedInnerTerm = CheckTerm(expr->inner); - expr->type = checkedInnerTerm->type; - return expr; - } - - Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) { // Check for type modifiers like 'out' and 'inout'. We need to differentiate the @@ -2729,31 +2619,32 @@ namespace Slang // we can return an overloaded result. if (auto overloadedExpr = as<OverloadedExpr>(baseExpr)) { - if (overloadedExpr->base) + // If a member (dynamic or static) lookup result contains both the actual definition + // and the interface definition obtained from inheritance, we want to filter out + // the interface definitions. + LookupResult filteredLookupResult; + for (auto lookupResult : overloadedExpr->lookupResult2) { - // If a member (dynamic or static) lookup result contains both the actual definition - // and the interface definition obtained from inheritance, we want to filter out - // the interface definitions. - LookupResult filteredLookupResult; - for (auto lookupResult : overloadedExpr->lookupResult2) + bool shouldRemove = false; + if (lookupResult.declRef.getParent().as<InterfaceDecl>()) { - bool shouldRemove = false; - if (lookupResult.declRef.getParent().as<InterfaceDecl>()) - shouldRemove = true; - if (!shouldRemove) - { - filteredLookupResult.items.add(lookupResult); - } + shouldRemove = true; + } + if (lookupResult.declRef.getDecl()->hasModifier<ExtensionExternVarModifier>()) + shouldRemove = true; + if (!shouldRemove) + { + filteredLookupResult.items.add(lookupResult); } - if (filteredLookupResult.items.getCount() == 1) - filteredLookupResult.item = filteredLookupResult.items.getFirst(); - baseExpr = createLookupResultExpr( - overloadedExpr->name, - filteredLookupResult, - overloadedExpr->base, - overloadedExpr->loc, - overloadedExpr); } + if (filteredLookupResult.items.getCount() == 1) + filteredLookupResult.item = filteredLookupResult.items.getFirst(); + baseExpr = createLookupResultExpr( + overloadedExpr->name, + filteredLookupResult, + overloadedExpr->base, + overloadedExpr->loc, + overloadedExpr); // TODO: handle other cases of OverloadedExpr that need filtering. } |
