diff options
| author | Yong He <yonghe@outlook.com> | 2022-11-29 14:16:05 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2022-11-29 14:16:05 -0800 |
| commit | af7f40063dfed1c651d33b93956c7623a7d2c050 (patch) | |
| tree | fd2225cffda6a9a887051bd00c5ccaba9ec6b5ea /source/slang/slang-check-decl.cpp | |
| parent | d85c7b809d02e6dc0844aab07e66a6bac2462017 (diff) | |
Complete removal of DifferentialBottom type. (#2537)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 122 |
1 files changed, 42 insertions, 80 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 4d2839b8d..5e6c6eedf 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -259,8 +259,6 @@ namespace Slang void visitParamDecl(ParamDecl* paramDecl); - void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context); - void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl); void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr); @@ -3317,73 +3315,52 @@ namespace Slang auto seqStmt = synth.pushSeqStmtScope(); blockStmt->body = seqStmt; - if (synFunc->returnType.type->equals(m_astBuilder->getDifferentialBottomType())) - { - // Trivial case, the `Differential` type is `DifferentialBottom`. - // We will just return `DifferentialBottom.dzero()`. - auto resultExpr = m_astBuilder->create<InvokeExpr>(); - auto dzeroMember = m_astBuilder->create<StaticMemberExpr>(); - auto base = m_astBuilder->create<SharedTypeExpr>(); - auto typetype = m_astBuilder->create<TypeType>(); - typetype->type = m_astBuilder->getDifferentialBottomType(); - base->type.type = typetype; - dzeroMember->baseExpression = base; - dzeroMember->name = getName("dzero"); - resultExpr->functionExpr = dzeroMember; - auto synReturn = m_astBuilder->create<ReturnStmt>(); - synReturn->expression = resultExpr; - seqStmt->stmts.add(synReturn); - } - else - { - // The general case. - // Create a variable for return value. - synth.pushVarScope(); - auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result")); - auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type); + // Create a variable for return value. + synth.pushVarScope(); + auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result")); + auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type); - for (auto member : context->parentDecl->members) - { - auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); - if (!derivativeAttr) - continue; - auto varMember = as<VarDeclBase>(member); - if (!varMember) - continue; - ensureDecl(varMember, DeclCheckState::ReadyForReference); - auto memberType = varMember->getType(); - auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); - if (!diffMemberType) - continue; - - // Construct reference exprs to the member's corresponding fields in each parameter. - List<Expr*> paramFields; - int paramIndex = 0; - for (auto arg : synArgs) - { - auto memberExpr = m_astBuilder->create<MemberExpr>(); - memberExpr->baseExpression = arg; - // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is - // Differential type. - memberExpr->name = varMember->getName(); - paramFields.add(memberExpr); - paramIndex++; - } + for (auto member : context->parentDecl->members) + { + auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>(); + if (!derivativeAttr) + continue; + auto varMember = as<VarDeclBase>(member); + if (!varMember) + continue; + ensureDecl(varMember, DeclCheckState::ReadyForReference); + auto memberType = varMember->getType(); + auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType); + if (!diffMemberType) + continue; - // Invoke the method for the field and assign the value to resultVar. - // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` - // is Differential type. - auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); - if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) - return false; - } + // Construct reference exprs to the member's corresponding fields in each parameter. + List<Expr*> paramFields; + int paramIndex = 0; + for (auto arg : synArgs) + { + auto memberExpr = m_astBuilder->create<MemberExpr>(); + memberExpr->baseExpression = arg; + // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is + // Differential type. + memberExpr->name = varMember->getName(); + paramFields.add(memberExpr); + paramIndex++; + } + + // Invoke the method for the field and assign the value to resultVar. + // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr` + // is Differential type. + auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName()); + if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields))) + return false; + } - // TODO: synthesize assignments for inherited members here. + // TODO: synthesize assignments for inherited members here. - auto synReturn = m_astBuilder->create<ReturnStmt>(); - synReturn->expression = resultVarExpr; - seqStmt->stmts.add(synReturn); - } + auto synReturn = m_astBuilder->create<ReturnStmt>(); + synReturn->expression = resultVarExpr; + seqStmt->stmts.add(synReturn); context->parentDecl->members.add(synFunc); context->parentDecl->invalidateMemberDictionary(); @@ -4633,21 +4610,6 @@ namespace Slang getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly); } - void SemanticsDeclBodyVisitor::_maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context) - { - auto parentDifferentiableAttr = context.getParentDifferentiableAttribute(); - if (parentDifferentiableAttr) - { - auto diffBottomType = m_astBuilder->getDifferentialBottomType(); - auto idifferentiable = DeclRef<InterfaceDecl>(m_astBuilder->getDifferentiableInterface(), nullptr); - auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(diffBottomType, idifferentiable)); - SLANG_ASSERT(witness); - parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.Add( - as<DeclRefType>(diffBottomType)->declRef, - witness); - } - } - void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl) { auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>(); |
