diff options
| author | Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> | 2022-10-20 14:22:00 -0400 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-10-20 11:22:00 -0700 |
| commit | 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a (patch) | |
| tree | e85158637680f783caaf7f4433a6844398cd8f7b /source/slang/slang-check-expr.cpp | |
| parent | 576c8407e60143682cd40c68101c6eae8563ca3d (diff) | |
Modified the new type system to support generic differentiable types … (#2413)
* Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions.
* Changed a few asserts to release asserts to avoid unreferenced variable errors
* Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl
* Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types.
* Moved the auto-diff passes to after the specialization step, added a more complex generics test
* Added a generics stress test and fixed AST-side logic. IR side needs some more work
* Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions
* Changed differential getters to use pointer types, added getter type checking
* Fixed some bugs related to diff type registration and differential getters
* Removed some superfluous code
* Removed some more unused code.
* Fixed an issue with witness substitution
* Minor fix
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source/slang/slang-check-expr.cpp')
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 278 |
1 files changed, 243 insertions, 35 deletions
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f1ccddf15..745532c27 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -719,8 +719,219 @@ namespace Slang return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink()); } + Type* SemanticsVisitor::_getDifferential(ASTBuilder* builder, Type* type) + { + if (auto ptrType = as<PtrTypeBase>(type)) + { + return builder->getPtrType( + _getDifferential(builder, ptrType->getValueType()), + ptrType->getClassInfo().m_name); + } + else if (auto arrayType = as<ArrayExpressionType>(type)) + { + return builder->getArrayType( + _getDifferential(builder, arrayType->baseType), + arrayType->arrayLength); + } + + if (auto declRefType = as<DeclRefType>(type)) + { + if (auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface()))) + { + auto diffTypeLookupResult = lookUpMember( + getASTBuilder(), + this, + getName("Differential"), + type, + Slang::LookupMask::type, + Slang::LookupOptions::None); + + diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult); + + if (!diffTypeLookupResult.isValid()) + { + // Diagnose no 'Differential' member. + getSink()->diagnose(declRefType->declRef, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential")); + } + else if (diffTypeLookupResult.isOverloaded()) + { + SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported"); + } + else + { + SharedTypeExpr* baseTypeExpr = m_astBuilder->create<SharedTypeExpr>(); + baseTypeExpr->base.type = type; + baseTypeExpr->type.type = m_astBuilder->getTypeType(type); + + auto diffTypeExpr = ConstructLookupResultExpr( + diffTypeLookupResult.item, + baseTypeExpr, + declRefType->declRef.getLoc(), + baseTypeExpr); + + return ExtractTypeFromTypeRepr(diffTypeExpr); + } + } + } + + return nullptr; + } + + void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) + { + if (!builder->isDifferentiableInterfaceAvailable()) + { + return; + } + + // Check for special cases such as PtrTypeBase<T> or Array<T> + // This could potentially be handled later by simply defining extensions + // for Ptr<T:IDifferentiable> etc.. + // + if (auto ptrType = as<PtrTypeBase>(type)) + { + maybeRegisterDifferentiableType(builder, ptrType->getValueType()); + return; + } + + if (auto arrayType = as<ArrayExpressionType>(type)) + { + maybeRegisterDifferentiableType(builder, arrayType->baseType); + return; + } + + if (auto declRefType = as<DeclRefType>(type)) + { + if (auto subtypeWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) + { + auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); + diffTypeContext->registerDifferentiableType((DeclRefType*)type, subtypeWitness); + } + + return; + } + } + + Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm) + { + // Check that member lookups on differentiable types have appropriate differential + // getters and setters. + if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) + { + + // Check if we have a parent container. If yes, then checkedTerm is + // referencing a member of this parent. + // + auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent()); + + // Check if we have an aggregate (i.e. struct-like) type. + // Ignore interfaces and the case when the term refers to a function + // + if (parentType->declRef.as<AggTypeDeclBase>() && + !parentType->declRef.as<InterfaceDecl>() && + !declRefExpr->declRef.as<CallableDecl>()) + { + // Check if the parent container type is differentiable. + if (auto parentDiffWitness = as<SubtypeWitness>( + tryGetInterfaceConformanceWitness( + parentType, getASTBuilder()->getDifferentiableInterface()))) + { + // If yes, the member in checkedTerm should have a differential getter and setter. + // Otherwise, <ERROR> + // + auto diffExpr = m_astBuilder->create<DifferentiableDeclRefExpr>(); + diffExpr->type = checkedTerm->type; + diffExpr->inner = checkedTerm; + + { + auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text); + auto getterResult = lookUpMember( + getASTBuilder(), + this, + getterName, + parentType, + Slang::LookupMask::Function, + Slang::LookupOptions::None); + + if (!getterResult.isValid()) + { + // Do nothing.. we assume that this field cannot be differentiated. + // Could this be confusing from a user perspective? + } + else if (getterResult.isOverloaded()) + { + // Diagnose ambiguous getter. + SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported"); + } + else + { + auto getterRefExpr = ConstructLookupResultExpr( + getterResult.item, + declRefExpr, + getterResult.item.declRef.getLoc(), + nullptr); + + // Check that the type is what we expect. + // We're going to do this in a very crude way for now. + // Ideally, we want to use the overload resolution and type + // coercion logic in ResolveInvoke() + // + + auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type); + auto diffParentType = _getDifferential(m_astBuilder, parentType); + + auto ptrDiffType = m_astBuilder->getPtrType(diffType); + auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType); + + auto funcType = as<FuncType>(getterRefExpr->type); + + if (!ptrDiffType->equals(funcType->getResultType())) + { + getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, + ptrDiffType, funcType->getResultType()); + } + + if (!inoutContainerDiffType->equals(funcType->getParamType(0))) + { + getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, + inoutContainerDiffType, funcType->getParamType(0)); + } + + diffExpr->getterExpr = getterRefExpr; + } + } + + return diffExpr; + } + } + } + + return checkedTerm; + } + Expr* SemanticsVisitor::CheckTerm(Expr* term) { + auto checkedTerm = _CheckTerm(term); + + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && + this->m_parentFunc->findModifier<JVPDerivativeModifier>()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + + if (auto declRefExpr = as<DeclRefExpr>(checkedTerm)) + { + checkedTerm = maybeMakeDifferentialExpr(checkedTerm); + } + } + + return checkedTerm; + } + + Expr* SemanticsVisitor::_CheckTerm(Expr* term) + { if (!term) return nullptr; // The process of checking a term/expression can end up introducing @@ -1677,6 +1888,13 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) + { + auto checkedInnerTerm = CheckTerm(expr->inner); + expr->type = checkedInnerTerm->type; + return expr; + } + Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) { @@ -1715,48 +1933,38 @@ namespace Slang return primalType; } - Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType) { - // Check/Resolve inner function declaration. - expr->baseFunction = CheckTerm(expr->baseFunction); + // Resolve JVP type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp - auto astBuilder = this->getASTBuilder(); + FuncType* jvpType = builder->create<FuncType>(); - if(auto primalType = as<FuncType>(expr->baseFunction->type)) - { - // Resolve JVP type here. - // Note that this type checking needs to be in sync with - // the auto-generation logic in slang-ir-jvp-diff.cpp - - FuncType* jvpType = astBuilder->create<FuncType>(); - - // The JVP return type is float if primal return type is float - // void otherwise. - // - jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType()); - - // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); - jvpType->errorType = primalType->errorType; - - for (UInt i = 0; i < primalType->getParamCount(); i++) - { - if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i))) - jvpType->paramTypes.add(jvpParamType); - } + // The JVP return type is float if primal return type is float + // void otherwise. + // + jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType()); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType())); + jvpType->errorType = originalType->errorType; - expr->type = jvpType; - } - else + for (UInt i = 0; i < originalType->getParamCount(); i++) { - // Error - expr->type = astBuilder->getErrorType(); - if (!as<ErrorType>(expr->baseFunction->type)) - { - getSink()->diagnose(expr->baseFunction->loc, Diagnostics::expectedFunction, expr->baseFunction->type); - } + if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i))) + jvpType->paramTypes.add(jvpParamType); } + return jvpType; + } + + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + { + this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); + + // Check/Resolve inner function declaration. + expr->baseFunction = CheckTerm(expr->baseFunction); return expr; } |
