From 41cb7c13e37ec32ffb6557d21da079d77151e136 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 24 Oct 2022 22:19:38 -0700 Subject: Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460) * wip: remove auto-diff for member access, add diff through property accessors. * Fix getter-setter test. * Fix getter-setter-multi test. * Fix nested-jvp test. * Use [DerivativeMember] attribute to differentiate through member access. * Clean up. * More cleanup. Co-authored-by: Yong He --- source/slang/slang-check-expr.cpp | 157 ++++++-------------------------------- 1 file changed, 24 insertions(+), 133 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 745532c27..29b44e726 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -755,7 +755,7 @@ namespace Slang } else if (diffTypeLookupResult.isOverloaded()) { - SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported"); + getSink()->diagnose(declRefType->declRef, Diagnostics::ambiguousReference, getName("Differential")); } else { @@ -774,7 +774,7 @@ namespace Slang } } - return nullptr; + return m_astBuilder->getErrorType(); } void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) @@ -813,103 +813,6 @@ namespace Slang } } - Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm) - { - // Check that member lookups on differentiable types have appropriate differential - // getters and setters. - if (auto declRefExpr = as(checkedTerm)) - { - - // Check if we have a parent container. If yes, then checkedTerm is - // referencing a member of this parent. - // - auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent()); - - // Check if we have an aggregate (i.e. struct-like) type. - // Ignore interfaces and the case when the term refers to a function - // - if (parentType->declRef.as() && - !parentType->declRef.as() && - !declRefExpr->declRef.as()) - { - // Check if the parent container type is differentiable. - if (auto parentDiffWitness = as( - tryGetInterfaceConformanceWitness( - parentType, getASTBuilder()->getDifferentiableInterface()))) - { - // If yes, the member in checkedTerm should have a differential getter and setter. - // Otherwise, - // - auto diffExpr = m_astBuilder->create(); - diffExpr->type = checkedTerm->type; - diffExpr->inner = checkedTerm; - - { - auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text); - auto getterResult = lookUpMember( - getASTBuilder(), - this, - getterName, - parentType, - Slang::LookupMask::Function, - Slang::LookupOptions::None); - - if (!getterResult.isValid()) - { - // Do nothing.. we assume that this field cannot be differentiated. - // Could this be confusing from a user perspective? - } - else if (getterResult.isOverloaded()) - { - // Diagnose ambiguous getter. - SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported"); - } - else - { - auto getterRefExpr = ConstructLookupResultExpr( - getterResult.item, - declRefExpr, - getterResult.item.declRef.getLoc(), - nullptr); - - // Check that the type is what we expect. - // We're going to do this in a very crude way for now. - // Ideally, we want to use the overload resolution and type - // coercion logic in ResolveInvoke() - // - - auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type); - auto diffParentType = _getDifferential(m_astBuilder, parentType); - - auto ptrDiffType = m_astBuilder->getPtrType(diffType); - auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType); - - auto funcType = as(getterRefExpr->type); - - if (!ptrDiffType->equals(funcType->getResultType())) - { - getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, - ptrDiffType, funcType->getResultType()); - } - - if (!inoutContainerDiffType->equals(funcType->getParamType(0))) - { - getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, - inoutContainerDiffType, funcType->getParamType(0)); - } - - diffExpr->getterExpr = getterRefExpr; - } - } - - return diffExpr; - } - } - } - - return checkedTerm; - } - Expr* SemanticsVisitor::CheckTerm(Expr* term) { auto checkedTerm = _CheckTerm(term); @@ -920,11 +823,6 @@ namespace Slang this->m_parentFunc->findModifier()) { maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); - - if (auto declRefExpr = as(checkedTerm)) - { - checkedTerm = maybeMakeDifferentialExpr(checkedTerm); - } } return checkedTerm; @@ -1888,14 +1786,6 @@ namespace Slang return expr; } - Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) - { - auto checkedInnerTerm = CheckTerm(expr->inner); - expr->type = checkedInnerTerm->type; - return expr; - } - - Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) { // Check for type modifiers like 'out' and 'inout'. We need to differentiate the @@ -2729,31 +2619,32 @@ namespace Slang // we can return an overloaded result. if (auto overloadedExpr = as(baseExpr)) { - if (overloadedExpr->base) + // If a member (dynamic or static) lookup result contains both the actual definition + // and the interface definition obtained from inheritance, we want to filter out + // the interface definitions. + LookupResult filteredLookupResult; + for (auto lookupResult : overloadedExpr->lookupResult2) { - // If a member (dynamic or static) lookup result contains both the actual definition - // and the interface definition obtained from inheritance, we want to filter out - // the interface definitions. - LookupResult filteredLookupResult; - for (auto lookupResult : overloadedExpr->lookupResult2) + bool shouldRemove = false; + if (lookupResult.declRef.getParent().as()) { - bool shouldRemove = false; - if (lookupResult.declRef.getParent().as()) - shouldRemove = true; - if (!shouldRemove) - { - filteredLookupResult.items.add(lookupResult); - } + shouldRemove = true; + } + if (lookupResult.declRef.getDecl()->hasModifier()) + shouldRemove = true; + if (!shouldRemove) + { + filteredLookupResult.items.add(lookupResult); } - if (filteredLookupResult.items.getCount() == 1) - filteredLookupResult.item = filteredLookupResult.items.getFirst(); - baseExpr = createLookupResultExpr( - overloadedExpr->name, - filteredLookupResult, - overloadedExpr->base, - overloadedExpr->loc, - overloadedExpr); } + if (filteredLookupResult.items.getCount() == 1) + filteredLookupResult.item = filteredLookupResult.items.getFirst(); + baseExpr = createLookupResultExpr( + overloadedExpr->name, + filteredLookupResult, + overloadedExpr->base, + overloadedExpr->loc, + overloadedExpr); // TODO: handle other cases of OverloadedExpr that need filtering. } -- cgit v1.2.3