summaryrefslogtreecommitdiffstats
path: root/source/slang/slang-check-decl.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
-rw-r--r--source/slang/slang-check-decl.cpp124
1 files changed, 114 insertions, 10 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp
index 78eccbc8c..07c8b0cba 100644
--- a/source/slang/slang-check-decl.cpp
+++ b/source/slang/slang-check-decl.cpp
@@ -1335,6 +1335,15 @@ namespace Slang
return arrayType->isUnsized();
}
+ EnumDecl* isEnumType(Type* type)
+ {
+ if (auto declRefType = as<DeclRefType>(type))
+ {
+ return as<EnumDecl>(declRefType->getDeclRef().getDecl());
+ }
+ return nullptr;
+ }
+
bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state)
{
if (state < DeclCheckState::DefinitionChecked)
@@ -3591,13 +3600,14 @@ namespace Slang
}
}
- FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness(
+ FunctionDeclBase* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness(
ConformanceCheckingContext* context,
- DeclRef<FuncDecl> requiredMemberDeclRef,
+ DeclRef<FunctionDeclBase> requiredMemberDeclRef,
List<Expr*>& synArgs,
ThisExpr*& synThis)
{
- auto synFuncDecl = m_astBuilder->create<FuncDecl>();
+ FunctionDeclBase* synFuncDecl = as<FunctionDeclBase>(m_astBuilder->createByNodeType(requiredMemberDeclRef.getDecl()->astNodeType));
+ SLANG_ASSERT(synFuncDecl);
synFuncDecl->ownedScope = m_astBuilder->create<Scope>();
synFuncDecl->ownedScope->containerDecl = synFuncDecl;
synFuncDecl->ownedScope->parent = getScope(context->parentDecl);
@@ -3913,6 +3923,14 @@ namespace Slang
{
SLANG_UNUSED(satisfyingMemberLookupResult);
+ if (as<EnumDecl>(context->parentDecl))
+ {
+ if (auto builtinRequirement = requiredMemberDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>())
+ {
+ return trySynthesizeEnumTypeMethodRequirementWitness(context, requiredMemberDeclRef, witnessTable, builtinRequirement->kind);
+ }
+ }
+
bool isInWrapperType = isWrapperTypeDecl(context->parentDecl);
if (!isInWrapperType)
{
@@ -4593,6 +4611,21 @@ namespace Slang
requiredFuncDeclRef,
witnessTable,
SynthesisPattern::AllInductive);
+ case BuiltinRequirementKind::And:
+ case BuiltinRequirementKind::Or:
+ case BuiltinRequirementKind::Not:
+ case BuiltinRequirementKind::BitAnd:
+ case BuiltinRequirementKind::BitNot:
+ case BuiltinRequirementKind::BitOr:
+ case BuiltinRequirementKind::BitXor:
+ case BuiltinRequirementKind::Shl:
+ case BuiltinRequirementKind::Shr:
+ case BuiltinRequirementKind::Equals:
+ case BuiltinRequirementKind::LessThan:
+ case BuiltinRequirementKind::LessThanOrEquals:
+ if (isEnumType(context->conformingType))
+ return trySynthesizeEnumTypeMethodRequirementWitness(context, requiredFuncDeclRef, witnessTable, builtinAttr->kind);
+ break;
}
}
return false;
@@ -4739,6 +4772,70 @@ namespace Slang
return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args)));
}
+ bool SemanticsVisitor::trySynthesizeEnumTypeMethodRequirementWitness(ConformanceCheckingContext* context,
+ DeclRef<FunctionDeclBase> funcDeclRef,
+ RefPtr<WitnessTable> witnessTable,
+ BuiltinRequirementKind requirementKind)
+ {
+ List<Expr*> synArgs;
+ ThisExpr* synThis = nullptr;
+ auto synFunc = synthesizeMethodSignatureForRequirementWitness(
+ context, funcDeclRef, synArgs, synThis);
+ auto intrinsicOpModifier = getASTBuilder()->create<IntrinsicOpModifier>();
+ switch (requirementKind)
+ {
+ case BuiltinRequirementKind::And:
+ intrinsicOpModifier->op = kIROp_And;
+ break;
+ case BuiltinRequirementKind::Or:
+ intrinsicOpModifier->op = kIROp_Or;
+ break;
+ case BuiltinRequirementKind::Not:
+ intrinsicOpModifier->op = kIROp_Not;
+ break;
+ case BuiltinRequirementKind::BitAnd:
+ intrinsicOpModifier->op = kIROp_BitAnd;
+ break;
+ case BuiltinRequirementKind::BitNot:
+ intrinsicOpModifier->op = kIROp_BitNot;
+ break;
+ case BuiltinRequirementKind::BitOr:
+ intrinsicOpModifier->op = kIROp_BitOr;
+ break;
+ case BuiltinRequirementKind::BitXor:
+ intrinsicOpModifier->op = kIROp_BitXor;
+ break;
+ case BuiltinRequirementKind::Shl:
+ intrinsicOpModifier->op = kIROp_Lsh;
+ break;
+ case BuiltinRequirementKind::Shr:
+ intrinsicOpModifier->op = kIROp_Rsh;
+ break;
+ case BuiltinRequirementKind::Equals:
+ intrinsicOpModifier->op = kIROp_Eql;
+ break;
+ case BuiltinRequirementKind::LessThan:
+ intrinsicOpModifier->op = kIROp_Less;
+ break;
+ case BuiltinRequirementKind::LessThanOrEquals:
+ intrinsicOpModifier->op = kIROp_Leq;
+ break;
+ case BuiltinRequirementKind::InitLogicalFromInt:
+ intrinsicOpModifier->op = kIROp_IntCast;
+ break;
+ default:
+ SLANG_ASSERT("unknown builtin requirement kind.");
+ }
+ synFunc->parentDecl = context->parentDecl;
+ synFunc->loc = context->parentDecl->closingSourceLoc;
+ synFunc->nameAndLoc.loc = synFunc->loc;
+ context->parentDecl->members.add(synFunc);
+ context->parentDecl->invalidateMemberDictionary();
+ addModifier(synFunc, intrinsicOpModifier);
+ witnessTable->add(funcDeclRef.getDecl(), RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc)));
+ return true;
+ }
+
bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness(
ConformanceCheckingContext* context,
DeclRef<Decl> requirementDeclRef,
@@ -4801,8 +4898,8 @@ namespace Slang
}
else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>())
{
- synFunc = synthesizeMethodSignatureForRequirementWitness(
- context, funcDeclRef, synArgs, synThis);
+ synFunc = as<FuncDecl>(synthesizeMethodSignatureForRequirementWitness(
+ context, funcDeclRef, synArgs, synThis));
}
SLANG_ASSERT(synFunc);
@@ -6083,6 +6180,8 @@ namespace Slang
auto tagType = decl->tagType;
+ auto isEnumFlags = decl->hasModifier<FlagsAttribute>();
+
// Check the enum cases in order.
for(auto caseDecl : decl->getMembersOfType<EnumCaseDecl>())
{
@@ -6102,7 +6201,7 @@ namespace Slang
// For any enum case that didn't provide an explicit
// tag value, derived an appropriate tag value.
- IntegerLiteralValue defaultTag = 0;
+ IntegerLiteralValue defaultTag = isEnumFlags ? 1 : 0;
for(auto caseDecl : decl->getMembersOfType<EnumCaseDecl>())
{
if(auto explicitTagValExpr = caseDecl->tagExpr)
@@ -6146,10 +6245,15 @@ namespace Slang
// Default tag for the next case will be one more than
// for the most recent case.
//
- // TODO: We might consider adding a `[flags]` attribute
- // that modifies this behavior to be `defaultTagForCase <<= 1`.
- //
- defaultTag++;
+ if (!isEnumFlags)
+ defaultTag++;
+ else
+ {
+ if (defaultTag == 0)
+ defaultTag = 1;
+ else
+ defaultTag <<= 1;
+ }
}
}