From 1093218d6f0e114eb9fa52d60ca525bf9dd9f98a Mon Sep 17 00:00:00 2001 From: Sai Praveen Bangaru <31557731+saipraveenb25@users.noreply.github.com> Date: Thu, 20 Oct 2022 14:22:00 -0400 Subject: Modified the new type system to support generic differentiable types … (#2413) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Modified the new type system to support generic differentiable types and added support for differentiating overloaded functions. * Changed a few asserts to release asserts to avoid unreferenced variable errors * Fixed a naming issue with TypeWitnessBreadcumb::Flavor::Decl * Added logic to avoid tracking differentiable types if the module does not use auto-diff or define differentiable types. * Moved the auto-diff passes to after the specialization step, added a more complex generics test * Added a generics stress test and fixed AST-side logic. IR side needs some more work * Added differential getter and setter logic, fixed multiple issues with DifferentiableTypeDictionary, added support for loops and conditions * Changed differential getters to use pointer types, added getter type checking * Fixed some bugs related to diff type registration and differential getters * Removed some superfluous code * Removed some more unused code. * Fixed an issue with witness substitution * Minor fix Co-authored-by: Yong He --- source/slang/slang-check-expr.cpp | 278 +++++++++++++++++++++++++++++++++----- 1 file changed, 243 insertions(+), 35 deletions(-) (limited to 'source/slang/slang-check-expr.cpp') diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index f1ccddf15..745532c27 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -719,7 +719,218 @@ namespace Slang return _resolveOverloadedExprImpl(overloadedExpr, mask, getSink()); } + Type* SemanticsVisitor::_getDifferential(ASTBuilder* builder, Type* type) + { + if (auto ptrType = as(type)) + { + return builder->getPtrType( + _getDifferential(builder, ptrType->getValueType()), + ptrType->getClassInfo().m_name); + } + else if (auto arrayType = as(type)) + { + return builder->getArrayType( + _getDifferential(builder, arrayType->baseType), + arrayType->arrayLength); + } + + if (auto declRefType = as(type)) + { + if (auto witness = as(tryGetInterfaceConformanceWitness(type, builder->getDifferentiableInterface()))) + { + auto diffTypeLookupResult = lookUpMember( + getASTBuilder(), + this, + getName("Differential"), + type, + Slang::LookupMask::type, + Slang::LookupOptions::None); + + diffTypeLookupResult = resolveOverloadedLookup(diffTypeLookupResult); + + if (!diffTypeLookupResult.isValid()) + { + // Diagnose no 'Differential' member. + getSink()->diagnose(declRefType->declRef, Diagnostics::typeDoesntImplementInterfaceRequirement, type, getName("Differential")); + } + else if (diffTypeLookupResult.isOverloaded()) + { + SLANG_UNIMPLEMENTED_X("Ambiguous differential type declarations not supported"); + } + else + { + SharedTypeExpr* baseTypeExpr = m_astBuilder->create(); + baseTypeExpr->base.type = type; + baseTypeExpr->type.type = m_astBuilder->getTypeType(type); + + auto diffTypeExpr = ConstructLookupResultExpr( + diffTypeLookupResult.item, + baseTypeExpr, + declRefType->declRef.getLoc(), + baseTypeExpr); + + return ExtractTypeFromTypeRepr(diffTypeExpr); + } + } + } + + return nullptr; + } + + void SemanticsVisitor::maybeRegisterDifferentiableType(ASTBuilder* builder, Type* type) + { + if (!builder->isDifferentiableInterfaceAvailable()) + { + return; + } + + // Check for special cases such as PtrTypeBase or Array + // This could potentially be handled later by simply defining extensions + // for Ptr etc.. + // + if (auto ptrType = as(type)) + { + maybeRegisterDifferentiableType(builder, ptrType->getValueType()); + return; + } + + if (auto arrayType = as(type)) + { + maybeRegisterDifferentiableType(builder, arrayType->baseType); + return; + } + + if (auto declRefType = as(type)) + { + if (auto subtypeWitness = as( + tryGetInterfaceConformanceWitness(type, getASTBuilder()->getDifferentiableInterface()))) + { + auto diffTypeContext = this->getShared()->innermostDiffTypeContext(); + diffTypeContext->registerDifferentiableType((DeclRefType*)type, subtypeWitness); + } + + return; + } + } + + Expr* SemanticsVisitor::maybeMakeDifferentialExpr(Expr* checkedTerm) + { + // Check that member lookups on differentiable types have appropriate differential + // getters and setters. + if (auto declRefExpr = as(checkedTerm)) + { + + // Check if we have a parent container. If yes, then checkedTerm is + // referencing a member of this parent. + // + auto parentType = DeclRefType::create(getASTBuilder(), declRefExpr->declRef.getParent()); + + // Check if we have an aggregate (i.e. struct-like) type. + // Ignore interfaces and the case when the term refers to a function + // + if (parentType->declRef.as() && + !parentType->declRef.as() && + !declRefExpr->declRef.as()) + { + // Check if the parent container type is differentiable. + if (auto parentDiffWitness = as( + tryGetInterfaceConformanceWitness( + parentType, getASTBuilder()->getDifferentiableInterface()))) + { + // If yes, the member in checkedTerm should have a differential getter and setter. + // Otherwise, + // + auto diffExpr = m_astBuilder->create(); + diffExpr->type = checkedTerm->type; + diffExpr->inner = checkedTerm; + + { + auto getterName = getName("__getDifferentialFor_" + declRefExpr->name->text); + auto getterResult = lookUpMember( + getASTBuilder(), + this, + getterName, + parentType, + Slang::LookupMask::Function, + Slang::LookupOptions::None); + + if (!getterResult.isValid()) + { + // Do nothing.. we assume that this field cannot be differentiated. + // Could this be confusing from a user perspective? + } + else if (getterResult.isOverloaded()) + { + // Diagnose ambiguous getter. + SLANG_UNIMPLEMENTED_X("Ambiguous differential getters not supported"); + } + else + { + auto getterRefExpr = ConstructLookupResultExpr( + getterResult.item, + declRefExpr, + getterResult.item.declRef.getLoc(), + nullptr); + + // Check that the type is what we expect. + // We're going to do this in a very crude way for now. + // Ideally, we want to use the overload resolution and type + // coercion logic in ResolveInvoke() + // + + auto diffType = _getDifferential(m_astBuilder, checkedTerm->type.type); + auto diffParentType = _getDifferential(m_astBuilder, parentType); + + auto ptrDiffType = m_astBuilder->getPtrType(diffType); + auto inoutContainerDiffType = m_astBuilder->getInOutType(diffParentType); + + auto funcType = as(getterRefExpr->type); + + if (!ptrDiffType->equals(funcType->getResultType())) + { + getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, + ptrDiffType, funcType->getResultType()); + } + + if (!inoutContainerDiffType->equals(funcType->getParamType(0))) + { + getSink()->diagnose(getterRefExpr, Diagnostics::typeMismatch, + inoutContainerDiffType, funcType->getParamType(0)); + } + + diffExpr->getterExpr = getterRefExpr; + } + } + + return diffExpr; + } + } + } + + return checkedTerm; + } + Expr* SemanticsVisitor::CheckTerm(Expr* term) + { + auto checkedTerm = _CheckTerm(term); + + // Differentiable type checking. + // TODO: This can be super slow. + if (this->m_parentFunc && + this->m_parentFunc->findModifier()) + { + maybeRegisterDifferentiableType(getASTBuilder(), checkedTerm->type.type); + + if (auto declRefExpr = as(checkedTerm)) + { + checkedTerm = maybeMakeDifferentialExpr(checkedTerm); + } + } + + return checkedTerm; + } + + Expr* SemanticsVisitor::_CheckTerm(Expr* term) { if (!term) return nullptr; @@ -1677,6 +1888,13 @@ namespace Slang return expr; } + Expr* SemanticsExprVisitor::visitDifferentiableDeclRefExpr(DifferentiableDeclRefExpr* expr) + { + auto checkedInnerTerm = CheckTerm(expr->inner); + expr->type = checkedInnerTerm->type; + return expr; + } + Type* SemanticsVisitor::_toDifferentialParamType(ASTBuilder* builder, Type* primalType) { @@ -1715,48 +1933,38 @@ namespace Slang return primalType; } - Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + Type* SemanticsVisitor::processJVPFuncType(ASTBuilder* builder, FuncType* originalType) { - // Check/Resolve inner function declaration. - expr->baseFunction = CheckTerm(expr->baseFunction); + // Resolve JVP type here. + // Note that this type checking needs to be in sync with + // the auto-generation logic in slang-ir-jvp-diff.cpp - auto astBuilder = this->getASTBuilder(); + FuncType* jvpType = builder->create(); - if(auto primalType = as(expr->baseFunction->type)) - { - // Resolve JVP type here. - // Note that this type checking needs to be in sync with - // the auto-generation logic in slang-ir-jvp-diff.cpp - - FuncType* jvpType = astBuilder->create(); - - // The JVP return type is float if primal return type is float - // void otherwise. - // - jvpType->resultType = _toJVPReturnType(astBuilder, primalType->getResultType()); - - // No support for differentiating function that throw errors, for now. - SLANG_ASSERT(primalType->errorType->equals(astBuilder->getBottomType())); - jvpType->errorType = primalType->errorType; - - for (UInt i = 0; i < primalType->getParamCount(); i++) - { - if(auto jvpParamType = _toDifferentialParamType(astBuilder, primalType->getParamType(i))) - jvpType->paramTypes.add(jvpParamType); - } + // The JVP return type is float if primal return type is float + // void otherwise. + // + jvpType->resultType = _toJVPReturnType(builder, originalType->getResultType()); + + // No support for differentiating function that throw errors, for now. + SLANG_ASSERT(originalType->errorType->equals(builder->getBottomType())); + jvpType->errorType = originalType->errorType; - expr->type = jvpType; - } - else + for (UInt i = 0; i < originalType->getParamCount(); i++) { - // Error - expr->type = astBuilder->getErrorType(); - if (!as(expr->baseFunction->type)) - { - getSink()->diagnose(expr->baseFunction->loc, Diagnostics::expectedFunction, expr->baseFunction->type); - } + if(auto jvpParamType = _toDifferentialParamType(builder, originalType->getParamType(i))) + jvpType->paramTypes.add(jvpParamType); } + return jvpType; + } + + Expr* SemanticsExprVisitor::visitJVPDifferentiateExpr(JVPDifferentiateExpr* expr) + { + this->getShared()->getDiffTypeContext()->requireDifferentiableTypeDictionary(); + + // Check/Resolve inner function declaration. + expr->baseFunction = CheckTerm(expr->baseFunction); return expr; } -- cgit v1.2.3