summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp28
1 files changed, 21 insertions, 7 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index b730069b6..2f91a6a77 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -1405,14 +1405,12 @@ Type* SemanticsVisitor::getDifferentialType(ASTBuilder* builder, Type* type, Sou
return result;
}
-void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(
- DeclRefType* type,
- SubtypeWitness* witness)
+void SemanticsVisitor::addDifferentiableTypeToDiffTypeRegistry(Type* type, SubtypeWitness* witness)
{
SLANG_RELEASE_ASSERT(m_parentDifferentiableAttr);
if (witness)
{
- m_parentDifferentiableAttr->addType(type->getDeclRef(), witness);
+ m_parentDifferentiableAttr->addType(type, witness);
}
}
@@ -1468,14 +1466,14 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
type,
getASTBuilder()->getDifferentiableInterfaceType())))
{
- addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
}
if (auto subtypeWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(
type,
getASTBuilder()->getDifferentiableRefInterfaceType())))
{
- addDifferentiableTypeToDiffTypeRegistry((DeclRefType*)type, subtypeWitness);
+ addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
}
if (auto aggTypeDeclRef = declRefType->getDeclRef().as<AggTypeDecl>())
@@ -1515,6 +1513,15 @@ void SemanticsVisitor::maybeRegisterDifferentiableTypeImplRecursive(ASTBuilder*
maybeRegisterDifferentiableTypeImplRecursive(builder, typePack->getElementType(i));
return;
}
+
+ // General check for types that may not be decl-ref-type, but still have some conformance to
+ // IDifferentiable/IDifferentiablePtrType
+ if (auto subtypeWitness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(
+ type,
+ getASTBuilder()->getDifferentiableInterfaceType())))
+ {
+ addDifferentiableTypeToDiffTypeRegistry(type, subtypeWitness);
+ }
}
@@ -4846,7 +4853,14 @@ Expr* SemanticsVisitor::checkBaseForMemberExpr(
auto baseExpr = inBaseExpr;
baseExpr = CheckTerm(baseExpr);
- return maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
+ auto resultBaseExpr =
+ maybeInsertImplicitOpForMemberBase(baseExpr, checkBaseContext, outNeedDeref);
+
+ // We might want to register differentiability on any implicit ops that we add in.
+ if (this->m_parentFunc && this->m_parentFunc->findModifier<DifferentiableAttribute>())
+ maybeRegisterDifferentiableType(getASTBuilder(), resultBaseExpr->type.type);
+
+ return resultBaseExpr;
}
Expr* SemanticsVisitor::checkGeneralMemberLookupExpr(MemberExpr* expr, Type* baseType)