diff options
Diffstat (limited to 'source/slang/slang-check-decl.cpp')
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 124 |
1 files changed, 114 insertions, 10 deletions
diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 78eccbc8c..07c8b0cba 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -1335,6 +1335,15 @@ namespace Slang return arrayType->isUnsized(); } + EnumDecl* isEnumType(Type* type) + { + if (auto declRefType = as<DeclRefType>(type)) + { + return as<EnumDecl>(declRefType->getDeclRef().getDecl()); + } + return nullptr; + } + bool SemanticsVisitor::shouldSkipChecking(Decl* decl, DeclCheckState state) { if (state < DeclCheckState::DefinitionChecked) @@ -3591,13 +3600,14 @@ namespace Slang } } - FuncDecl* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( + FunctionDeclBase* SemanticsVisitor::synthesizeMethodSignatureForRequirementWitness( ConformanceCheckingContext* context, - DeclRef<FuncDecl> requiredMemberDeclRef, + DeclRef<FunctionDeclBase> requiredMemberDeclRef, List<Expr*>& synArgs, ThisExpr*& synThis) { - auto synFuncDecl = m_astBuilder->create<FuncDecl>(); + FunctionDeclBase* synFuncDecl = as<FunctionDeclBase>(m_astBuilder->createByNodeType(requiredMemberDeclRef.getDecl()->astNodeType)); + SLANG_ASSERT(synFuncDecl); synFuncDecl->ownedScope = m_astBuilder->create<Scope>(); synFuncDecl->ownedScope->containerDecl = synFuncDecl; synFuncDecl->ownedScope->parent = getScope(context->parentDecl); @@ -3913,6 +3923,14 @@ namespace Slang { SLANG_UNUSED(satisfyingMemberLookupResult); + if (as<EnumDecl>(context->parentDecl)) + { + if (auto builtinRequirement = requiredMemberDeclRef.getDecl()->findModifier<BuiltinRequirementModifier>()) + { + return trySynthesizeEnumTypeMethodRequirementWitness(context, requiredMemberDeclRef, witnessTable, builtinRequirement->kind); + } + } + bool isInWrapperType = isWrapperTypeDecl(context->parentDecl); if (!isInWrapperType) { @@ -4593,6 +4611,21 @@ namespace Slang requiredFuncDeclRef, witnessTable, SynthesisPattern::AllInductive); + case BuiltinRequirementKind::And: + case BuiltinRequirementKind::Or: + case BuiltinRequirementKind::Not: + case BuiltinRequirementKind::BitAnd: + case BuiltinRequirementKind::BitNot: + case BuiltinRequirementKind::BitOr: + case BuiltinRequirementKind::BitXor: + case BuiltinRequirementKind::Shl: + case BuiltinRequirementKind::Shr: + case BuiltinRequirementKind::Equals: + case BuiltinRequirementKind::LessThan: + case BuiltinRequirementKind::LessThanOrEquals: + if (isEnumType(context->conformingType)) + return trySynthesizeEnumTypeMethodRequirementWitness(context, requiredFuncDeclRef, witnessTable, builtinAttr->kind); + break; } } return false; @@ -4739,6 +4772,70 @@ namespace Slang return synth.emitAssignStmt(leftValue, synth.emitInvokeExpr(callee, _Move(args))); } + bool SemanticsVisitor::trySynthesizeEnumTypeMethodRequirementWitness(ConformanceCheckingContext* context, + DeclRef<FunctionDeclBase> funcDeclRef, + RefPtr<WitnessTable> witnessTable, + BuiltinRequirementKind requirementKind) + { + List<Expr*> synArgs; + ThisExpr* synThis = nullptr; + auto synFunc = synthesizeMethodSignatureForRequirementWitness( + context, funcDeclRef, synArgs, synThis); + auto intrinsicOpModifier = getASTBuilder()->create<IntrinsicOpModifier>(); + switch (requirementKind) + { + case BuiltinRequirementKind::And: + intrinsicOpModifier->op = kIROp_And; + break; + case BuiltinRequirementKind::Or: + intrinsicOpModifier->op = kIROp_Or; + break; + case BuiltinRequirementKind::Not: + intrinsicOpModifier->op = kIROp_Not; + break; + case BuiltinRequirementKind::BitAnd: + intrinsicOpModifier->op = kIROp_BitAnd; + break; + case BuiltinRequirementKind::BitNot: + intrinsicOpModifier->op = kIROp_BitNot; + break; + case BuiltinRequirementKind::BitOr: + intrinsicOpModifier->op = kIROp_BitOr; + break; + case BuiltinRequirementKind::BitXor: + intrinsicOpModifier->op = kIROp_BitXor; + break; + case BuiltinRequirementKind::Shl: + intrinsicOpModifier->op = kIROp_Lsh; + break; + case BuiltinRequirementKind::Shr: + intrinsicOpModifier->op = kIROp_Rsh; + break; + case BuiltinRequirementKind::Equals: + intrinsicOpModifier->op = kIROp_Eql; + break; + case BuiltinRequirementKind::LessThan: + intrinsicOpModifier->op = kIROp_Less; + break; + case BuiltinRequirementKind::LessThanOrEquals: + intrinsicOpModifier->op = kIROp_Leq; + break; + case BuiltinRequirementKind::InitLogicalFromInt: + intrinsicOpModifier->op = kIROp_IntCast; + break; + default: + SLANG_ASSERT("unknown builtin requirement kind."); + } + synFunc->parentDecl = context->parentDecl; + synFunc->loc = context->parentDecl->closingSourceLoc; + synFunc->nameAndLoc.loc = synFunc->loc; + context->parentDecl->members.add(synFunc); + context->parentDecl->invalidateMemberDictionary(); + addModifier(synFunc, intrinsicOpModifier); + witnessTable->add(funcDeclRef.getDecl(), RequirementWitness(m_astBuilder->getDirectDeclRef(synFunc))); + return true; + } + bool SemanticsVisitor::trySynthesizeDifferentialMethodRequirementWitness( ConformanceCheckingContext* context, DeclRef<Decl> requirementDeclRef, @@ -4801,8 +4898,8 @@ namespace Slang } else if (auto funcDeclRef = requirementDeclRef.as<FuncDecl>()) { - synFunc = synthesizeMethodSignatureForRequirementWitness( - context, funcDeclRef, synArgs, synThis); + synFunc = as<FuncDecl>(synthesizeMethodSignatureForRequirementWitness( + context, funcDeclRef, synArgs, synThis)); } SLANG_ASSERT(synFunc); @@ -6083,6 +6180,8 @@ namespace Slang auto tagType = decl->tagType; + auto isEnumFlags = decl->hasModifier<FlagsAttribute>(); + // Check the enum cases in order. for(auto caseDecl : decl->getMembersOfType<EnumCaseDecl>()) { @@ -6102,7 +6201,7 @@ namespace Slang // For any enum case that didn't provide an explicit // tag value, derived an appropriate tag value. - IntegerLiteralValue defaultTag = 0; + IntegerLiteralValue defaultTag = isEnumFlags ? 1 : 0; for(auto caseDecl : decl->getMembersOfType<EnumCaseDecl>()) { if(auto explicitTagValExpr = caseDecl->tagExpr) @@ -6146,10 +6245,15 @@ namespace Slang // Default tag for the next case will be one more than // for the most recent case. // - // TODO: We might consider adding a `[flags]` attribute - // that modifies this behavior to be `defaultTagForCase <<= 1`. - // - defaultTag++; + if (!isEnumFlags) + defaultTag++; + else + { + if (defaultTag == 0) + defaultTag = 1; + else + defaultTag <<= 1; + } } } |
