summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-expr.cpp
diff options
context:
space:
mode:
authorSai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com>2022-10-20 14:22:00 -0400
committerGitHub <noreply@github.com>2022-10-20 11:22:00 -0700
commit1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch)
treee85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-check-expr.cpp
parent576c8407e60143682cd40c68101c6eae8563ca3d (diff)
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
-rw-r--r--source/slang/slang-check-expr.cpp278
1 files changed, 243 insertions, 35 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index f1ccddf15..745532c27 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -719,8 +719,219 @@ namespace Slang
return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink());
}
+ Type* SemanticsVisitor::_getDifferential(ASTBuilder* builder, Type* type)
+ {
+ if (auto ptrType = as<PtrTypeBase>(type))
+ {
+ return builder->getPtrType(
+ _getDifferential(builder, ptrType->getValueType()),
+ ptrType->getClassInfo().m_name);
+ }
+ else if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ return builder->getArrayType(
+ _getDifferential(builder, arrayType->baseType),
+ arrayType->arrayLength);
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface())))
+ {
+ auto diffTypeLookupResult = lookUpMember(
+ getASTBuilder(),
+ this,
+ getName("Differential"),
+ type,
+ Slang::LookupMask::type,
+ Slang::LookupOptions::None);
+
+ diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult);
+
+ if (!diffTypeLookupResult.isValid())
+ {
+ // Diagnose no 'Differential' member.
+ getSink()->diagnose(declRefType->declRef, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential"));
+ }
+ else if (diffTypeLookupResult.isOverloaded())
+ {
+ SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported");
+ }
+ else
+ {
+ SharedTypeExpr* baseTypeExpr = m_astBuilder->create<SharedTypeExpr>();
+ baseTypeExpr->base.type = type;
+ baseTypeExpr->type.type = m_astBuilder->getTypeType(type);
+
+ auto diffTypeExpr = ConstructLookupResultExpr(
+ diffTypeLookupResult.item,
+ baseTypeExpr,
+ declRefType->declRef.getLoc(),
+ baseTypeExpr);
+
+ return ExtractTypeFromTypeRepr(diffTypeExpr);
+ }
+ }
+ }
+
+ return nullptr;
+ }
+
+ void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type)
+ {
+ if (!builder->isDifferentiableInterfaceAvailable())
+ {
+ return;
+ }
+
+ // Check for special cases such as PtrTypeBase<T> or Array<T>
+ // This could potentially be handled later by simply defining extensions
+ // for Ptr<T:IDifferentiable> etc..
+ //
+ if (auto ptrType = as<PtrTypeBase>(type))
+ {
+ maybeRegisterDifferentiableType(builder, ptrType->getValueType());
+ return;
+ }
+
+ if (auto arrayType = as<ArrayExpressionType>(type))
+ {
+ maybeRegisterDifferentiableType(builder, arrayType->baseType);
+ return;
+ }
+
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ if (auto subtypeWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface())))
+ {
+ auto diffTypeContext = this->getShared()->innermostDiffTypeContext();
+ diffTypeContext->registerDifferentiableType((DeclRefType*)type, subtypeWitness);
+ }
+
+ return;
+ }
+ }
+
+ Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm)
+ {
+ // Check that member lookups on differentiable types have appropriate differential
+ // getters and setters.
+ if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
+ {
+
+ // Check if we have a parent container. If yes, then checkedTerm is
+ // referencing a member of this parent.
+ //
+ auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent());
+
+ // Check if we have an aggregate (i.e. struct-like) type.
+ // Ignore interfaces and the case when the term refers to a function
+ //
+ if (parentType->declRef.as<AggTypeDeclBase>() &&
+ !parentType->declRef.as<InterfaceDecl>() &&
+ !declRefExpr->declRef.as<CallableDecl>())
+ {
+ // Check if the parent container type is differentiable.
+ if (auto parentDiffWitness = as<SubtypeWitness>(
+ tryGetInterfaceConformanceWitness(
+ parentType, getASTBuilder()->getDifferentiableInterface())))
+ {
+ // If yes, the member in checkedTerm should have a differential getter and setter.
+ // Otherwise, <ERROR>
+ //
+ auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>();
+ diffExpr->type = checkedTerm->type;
+ diffExpr->inner = checkedTerm;
+
+ {
+ auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text);
+ auto getterResult = lookUpMember(
+ getASTBuilder(),
+ this,
+ getterName,
+ parentType,
+ Slang::LookupMask::Function,
+ Slang::LookupOptions::None);
+
+ if (!getterResult.isValid())
+ {
+ // Do nothing.. we assume that this field cannot be differentiated.
+ // Could this be confusing from a user perspective?
+ }
+ else if (getterResult.isOverloaded())
+ {
+ // Diagnose ambiguous getter.
+ SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported");
+ }
+ else
+ {
+ auto getterRefExpr = ConstructLookupResultExpr(
+ getterResult.item,
+ declRefExpr,
+ getterResult.item.declRef.getLoc(),
+ nullptr);
+
+ // Check that the type is what we expect.
+ // We're going to do this in a very crude way for now.
+ // Ideally, we want to use the overload resolution and type
+ // coercion logic in ResolveInvoke()
+ //
+
+ auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type);
+ auto diffParentType = _getDifferential(m_astBuilder, parentType);
+
+ auto ptrDiffType = m_astBuilder->getPtrType(diffType);
+ auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType);
+
+ auto funcType = as<FuncType>(getterRefExpr->type);
+
+ if (!ptrDiffType->equals(funcType->getResultType()))
+ {
+ getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
+ ptrDiffType, funcType->getResultType());
+ }
+
+ if (!inoutContainerDiffType->equals(funcType->getParamType(0)))
+ {
+ getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch,
+ inoutContainerDiffType, funcType->getParamType(0));
+ }
+
+ diffExpr->getterExpr = getterRefExpr;
+ }
+ }
+
+ return diffExpr;
+ }
+ }
+ }
+
+ return checkedTerm;
+ }
+
Expr* SemanticsVisitor::CheckTerm(Expr* term)
{
+ auto checkedTerm = _CheckTerm(term);
+
+ // Differentiable type checking.
+ // TODO: This can be super slow.
+ if (this->m_parentFunc &&
+ this->m_parentFunc->findModifier<JVPDerivativeModifier>())
+ {
+ maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type);
+
+ if (auto declRefExpr = as<DeclRefExpr>(checkedTerm))
+ {
+ checkedTerm = maybeMakeDifferentialExpr(checkedTerm);
+ }
+ }
+
+ return checkedTerm;
+ }
+
+ Expr* SemanticsVisitor::_CheckTerm(Expr* term)
+ {
if (!term) return nullptr;
// The process of checking a term/expression can end up introducing
@@ -1677,6 +1888,13 @@ namespace Slang
return expr;
}
+ Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr)
+ {
+ auto checkedInnerTerm = CheckTerm(expr->inner);
+ expr->type = checkedInnerTerm->type;
+ return expr;
+ }
+
Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType)
{
@@ -1715,48 +1933,38 @@ namespace Slang
return primalType;
}
- Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType)
{
- // Check/Resolve inner function declaration.
- expr->baseFunction = CheckTerm(expr->baseFunction);
+ // Resolve JVP type here.
+ // Note that this type checking needs to be in sync with
+ // the auto-generation logic in slang-ir-jvp-diff.cpp
- auto astBuilder = this->getASTBuilder();
+ FuncType* jvpType = builder->create<FuncType>();
- if(auto primalType = as<FuncType>(expr->baseFunction->type))
- {
- // Resolve JVP type here.
- // Note that this type checking needs to be in sync with
- // the auto-generation logic in slang-ir-jvp-diff.cpp
-
- FuncType* jvpType = astBuilder->create<FuncType>();
-
- // The JVP return type is float if primal return type is float
- // void otherwise.
- //
- jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType());
-
- // No support for differentiating function that throw errors, for now.
- SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType()));
- jvpType->errorType = primalType->errorType;
-
- for (UInt i = 0; i < primalType->getParamCount(); i++)
- {
- if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i)))
- jvpType->paramTypes.add(jvpParamType);
- }
+ // The JVP return type is float if primal return type is float
+ // void otherwise.
+ //
+ jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType());
+
+ // No support for differentiating function that throw errors, for now.
+ SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType()));
+ jvpType->errorType = originalType->errorType;
- expr->type = jvpType;
- }
- else
+ for (UInt i = 0; i < originalType->getParamCount(); i++)
{
- // Error
- expr->type = astBuilder->getErrorType();
- if (!as<ErrorType>(expr->baseFunction->type))
- {
- getSink()->diagnose(expr->baseFunction->loc, Diagnostics::expectedFunction, expr->baseFunction->type);
- }
+ if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i)))
+ jvpType->paramTypes.add(jvpParamType);
}
+ return jvpType;
+ }
+
+ Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr)
+ {
+ this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary();
+
+ // Check/Resolve inner function declaration.
+ expr->baseFunction = CheckTerm(expr->baseFunction);
return expr;
}