summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-10-24 22:19:38 -0700
committerGitHub <noreply@github.com>2022-10-24 22:19:38 -0700
commit41cb7c13e37ec32ffb6557d21da079d77151e136 (patch)
tree38d2c44938e2679c42c5c0e73f5411e59015df93 /source/slang/slang-check-expr.cpp
parent1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (diff)
Rework differentiation of member access through `[DerivativeMember(DiffType.field)]` (#2460)
* wip: remove auto-diff for member access, add diff through property accessors. * Fix getter-setter test. * Fix getter-setter-multi test. * Fix nested-jvp test. * Use [DerivativeMember] attribute to differentiate through member access. * Clean up. * More cleanup. Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp157
1 files changed, 24 insertions, 133 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 745532c27..29b44e726 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -755,7 +755,7 @@ namespace Slang
}
else if (diffTypeLookupResult.isOverloaded())
{
- SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported");
+ getSink()->diagnose(declRefType->declRef, Diagnostics::ambiguousReference, getName("Differential"));
}
else
{
@@ -774,7 +774,7 @@ namespace Slang
}
}
- return nullptr;
+ return m_astBuilder->getErrorType();
}
void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
@@ -813,103 +813,6 @@ namespace Slang
}
}
- Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm)
- {
- // Check that member lookups on differentiable types have appropriate differential
- // getters and setters.
- if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
- {
-
- // Check if we have a parent container. If yes, then checkedTerm is
- // referencing a member of this parent.
- //
- auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent());
-
- // Check if we have an aggregate (i.e. struct-like) type.
- // Ignore interfaces and the case when the term refers to a function
- //
- if (parentType->declRef.as<AggTypeDeclBase>() &&
- !parentType->declRef.as<InterfaceDecl>() &&
- !declRefExpr->declRef.as<CallableDecl>())
- {
- // Check if the parent container type is differentiable.
- if (auto parentDiffWitness = as<SubtypeWitness>(
- tryGetInterfaceConformanceWitness(
- parentType, getASTBuilder()->getDifferentiableInterface())))
- {
- // If yes, the member in checkedTerm should have a differential getter and setter.
- // Otherwise, <ERROR>
- //
- auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>();
- diffExpr->type = checkedTerm->type;
- diffExpr->inner = checkedTerm;
-
- {
- auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text);
- auto getterResult = lookUpMember(
- getASTBuilder(),
- this,
- getterName,
- parentType,
- Slang::LookupMask::Function,
- Slang::LookupOptions::None);
-
- if (!getterResult.isValid())
- {
- // Do nothing.. we assume that this field cannot be differentiated.
- // Could this be confusing from a user perspective?
- }
- else if (getterResult.isOverloaded())
- {
- // Diagnose ambiguous getter.
- SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported");
- }
- else
- {
- auto getterRefExpr = ConstructLookupResultExpr(
- getterResult.item,
- declRefExpr,
- getterResult.item.declRef.getLoc(),
- nullptr);
-
- // Check that the type is what we expect.
- // We're going to do this in a very crude way for now.
- // Ideally, we want to use the overload resolution and type
- // coercion logic in ResolveInvoke()
- //
-
- auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type);
- auto diffParentType = _getDifferential(m_astBuilder, parentType);
-
- auto ptrDiffType = m_astBuilder->getPtrType(diffType);
- auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType);
-
- auto funcType = as<FuncType>(getterRefExpr->type);
-
- if (!ptrDiffType->equals(funcType->getResultType()))
- {
- getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
- ptrDiffType, funcType->getResultType());
- }
-
- if (!inoutContainerDiffType->equals(funcType->getParamType(0)))
- {
- getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
- inoutContainerDiffType, funcType->getParamType(0));
- }
-
- diffExpr->getterExpr = getterRefExpr;
- }
- }
-
- return diffExpr;
- }
- }
- }
-
- return checkedTerm;
- }
-
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
auto checkedTerm = _CheckTerm(term);
@@ -920,11 +823,6 @@ namespace Slang
this->m_parentFunc->findModifier<JVPDerivativeModifier>())
{
maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
-
- if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
- {
- checkedTerm = maybeMakeDifferentialExpr(checkedTerm);
- }
}
return checkedTerm;
@@ -1888,14 +1786,6 @@ namespace Slang
return expr;
}
- Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
- {
- auto checkedInnerTerm = CheckTerm(expr->inner);
- expr->type = checkedInnerTerm->type;
- return expr;
- }
-
-
Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
{
// Check for type modifiers like 'out' and 'inout'. We need to differentiate the
@@ -2729,31 +2619,32 @@ namespace Slang
// we can return an overloaded result.
if (auto overloadedExpr = as<OverloadedExpr>(baseExpr))
{
- if (overloadedExpr->base)
+ // If a member (dynamic or static) lookup result contains both the actual definition
+ // and the interface definition obtained from inheritance, we want to filter out
+ // the interface definitions.
+ LookupResult filteredLookupResult;
+ for (auto lookupResult : overloadedExpr->lookupResult2)
{
- // If a member (dynamic or static) lookup result contains both the actual definition
- // and the interface definition obtained from inheritance, we want to filter out
- // the interface definitions.
- LookupResult filteredLookupResult;
- for (auto lookupResult : overloadedExpr->lookupResult2)
+ bool shouldRemove = false;
+ if (lookupResult.declRef.getParent().as<InterfaceDecl>())
{
- bool shouldRemove = false;
- if (lookupResult.declRef.getParent().as<InterfaceDecl>())
- shouldRemove = true;
- if (!shouldRemove)
- {
- filteredLookupResult.items.add(lookupResult);
- }
+ shouldRemove = true;
+ }
+ if (lookupResult.declRef.getDecl()->hasModifier<ExtensionExternVarModifier>())
+ shouldRemove = true;
+ if (!shouldRemove)
+ {
+ filteredLookupResult.items.add(lookupResult);
}
- if (filteredLookupResult.items.getCount() == 1)
- filteredLookupResult.item = filteredLookupResult.items.getFirst();
- baseExpr = createLookupResultExpr(
- overloadedExpr->name,
- filteredLookupResult,
- overloadedExpr->base,
- overloadedExpr->loc,
- overloadedExpr);
}
+ if (filteredLookupResult.items.getCount() == 1)
+ filteredLookupResult.item = filteredLookupResult.items.getFirst();
+ baseExpr = createLookupResultExpr(
+ overloadedExpr->name,
+ filteredLookupResult,
+ overloadedExpr->base,
+ overloadedExpr->loc,
+ overloadedExpr);
// TODO: handle other cases of OverloadedExpr that need filtering.
}