summaryrefslogtreecommitdiff
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2022-11-29 14:16:05 -0800
committerGitHub <noreply@github.com>2022-11-29 14:16:05 -0800
commitaf7f40063dfed1c651d33b93956c7623a7d2c050 (patch)
treefd2225cffda6a9a887051bd00c5ccaba9ec6b5ea /source/slang/slang-check-decl.cpp
parentd85c7b809d02e6dc0844aab07e66a6bac2462017 (diff)
Complete removal of DifferentialBottom type. (#2537)
Co-authored-by: Yong He <yhe@nvidia.com>
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp122
1 files changed, 42 insertions, 80 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 4d2839b8d..5e6c6eedf 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -259,8 +259,6 @@ namespace Slang
void visitParamDecl(ParamDecl* paramDecl);
- void _maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context);
-
void checkDerivativeOfAttribute(FunctionDeclBase* funcDecl);
void checkDerivativeAttribute(FunctionDeclBase* funcDecl, ForwardDerivativeAttribute* attr);
@@ -3317,73 +3315,52 @@ namespace Slang
auto seqStmt = synth.pushSeqStmtScope();
blockStmt->body = seqStmt;
- if (synFunc->returnType.type->equals(m_astBuilder->getDifferentialBottomType()))
- {
- // Trivial case, the `Differential` type is `DifferentialBottom`.
- // We will just return `DifferentialBottom.dzero()`.
- auto resultExpr = m_astBuilder->create<InvokeExpr>();
- auto dzeroMember = m_astBuilder->create<StaticMemberExpr>();
- auto base = m_astBuilder->create<SharedTypeExpr>();
- auto typetype = m_astBuilder->create<TypeType>();
- typetype->type = m_astBuilder->getDifferentialBottomType();
- base->type.type = typetype;
- dzeroMember->baseExpression = base;
- dzeroMember->name = getName("dzero");
- resultExpr->functionExpr = dzeroMember;
- auto synReturn = m_astBuilder->create<ReturnStmt>();
- synReturn->expression = resultExpr;
- seqStmt->stmts.add(synReturn);
- }
- else
- {
- // The general case.
- // Create a variable for return value.
- synth.pushVarScope();
- auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result"));
- auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type);
+ // Create a variable for return value.
+ synth.pushVarScope();
+ auto varStmt = synth.emitVarDeclStmt(synFunc->returnType.type, getName("result"));
+ auto resultVarExpr = synth.emitVarExpr(varStmt, synFunc->returnType.type);
- for (auto member : context->parentDecl->members)
- {
- auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
- if (!derivativeAttr)
- continue;
- auto varMember = as<VarDeclBase>(member);
- if (!varMember)
- continue;
- ensureDecl(varMember, DeclCheckState::ReadyForReference);
- auto memberType = varMember->getType();
- auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType);
- if (!diffMemberType)
- continue;
-
- // Construct reference exprs to the member's corresponding fields in each parameter.
- List<Expr*> paramFields;
- int paramIndex = 0;
- for (auto arg : synArgs)
- {
- auto memberExpr = m_astBuilder->create<MemberExpr>();
- memberExpr->baseExpression = arg;
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
- // Differential type.
- memberExpr->name = varMember->getName();
- paramFields.add(memberExpr);
- paramIndex++;
- }
+ for (auto member : context->parentDecl->members)
+ {
+ auto derivativeAttr = member->findModifier<DerivativeMemberAttribute>();
+ if (!derivativeAttr)
+ continue;
+ auto varMember = as<VarDeclBase>(member);
+ if (!varMember)
+ continue;
+ ensureDecl(varMember, DeclCheckState::ReadyForReference);
+ auto memberType = varMember->getType();
+ auto diffMemberType = tryGetDifferentialType(m_astBuilder, memberType);
+ if (!diffMemberType)
+ continue;
- // Invoke the method for the field and assign the value to resultVar.
- // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
- // is Differential type.
- auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName());
- if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields)))
- return false;
- }
+ // Construct reference exprs to the member's corresponding fields in each parameter.
+ List<Expr*> paramFields;
+ int paramIndex = 0;
+ for (auto arg : synArgs)
+ {
+ auto memberExpr = m_astBuilder->create<MemberExpr>();
+ memberExpr->baseExpression = arg;
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `arg` is
+ // Differential type.
+ memberExpr->name = varMember->getName();
+ paramFields.add(memberExpr);
+ paramIndex++;
+ }
+
+ // Invoke the method for the field and assign the value to resultVar.
+ // TODO: we should probably fetch the name from `[DerivativeMember]` if `resultVarExpr`
+ // is Differential type.
+ auto leftVal = synth.emitMemberExpr(resultVarExpr, varMember->getName());
+ if (!_synthesizeMemberAssignMemberHelper(synth, requirementDeclRef.getName(), memberType, leftVal, _Move(paramFields)))
+ return false;
+ }
- // TODO: synthesize assignments for inherited members here.
+ // TODO: synthesize assignments for inherited members here.
- auto synReturn = m_astBuilder->create<ReturnStmt>();
- synReturn->expression = resultVarExpr;
- seqStmt->stmts.add(synReturn);
- }
+ auto synReturn = m_astBuilder->create<ReturnStmt>();
+ synReturn->expression = resultVarExpr;
+ seqStmt->stmts.add(synReturn);
context->parentDecl->members.add(synFunc);
context->parentDecl->invalidateMemberDictionary();
@@ -4633,21 +4610,6 @@ namespace Slang
getSink()->diagnose(decl, Slang::Diagnostics::assocTypeInInterfaceOnly);
}
- void SemanticsDeclBodyVisitor::_maybeRegisterDifferentialBottomTypeConformance(SemanticsContext& context)
- {
- auto parentDifferentiableAttr = context.getParentDifferentiableAttribute();
- if (parentDifferentiableAttr)
- {
- auto diffBottomType = m_astBuilder->getDifferentialBottomType();
- auto idifferentiable = DeclRef<InterfaceDecl>(m_astBuilder->getDifferentiableInterface(), nullptr);
- auto witness = as<SubtypeWitness>(tryGetInterfaceConformanceWitness(diffBottomType, idifferentiable));
- SLANG_ASSERT(witness);
- parentDifferentiableAttr->m_mapTypeToIDifferentiableWitness.Add(
- as<DeclRefType>(diffBottomType)->declRef,
- witness);
- }
- }
-
void SemanticsDeclBodyVisitor::checkDerivativeOfAttribute(FunctionDeclBase* funcDecl)
{
auto attr = funcDecl->findModifier<ForwardDerivativeOfAttribute>();