diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/hlsl.meta.slang | 13 | ||||
| -rw-r--r-- | source/slang/slang-check-decl.cpp | 133 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 13 | ||||
| -rw-r--r-- | source/slang/slang-ir-hlsl-legalize.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-ir-link.cpp | 4 | ||||
| -rw-r--r-- | source/slang/slang-ir-use-uninitialized-values.cpp | 2 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 6 | ||||
| -rw-r--r-- | source/slang/slang-serialize-container.cpp | 16 |
9 files changed, 127 insertions, 64 deletions
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index a10e747c0..81b28b30a 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -524,6 +524,13 @@ interface ITexelElement __init(Element x); } +extension<T:__BuiltinArithmeticType> T : ITexelElement +{ + typealias Element = T; + static const int elementCount = 1; + __intrinsic_op(0) __init(Element x); +} + ${{{ // Scalar types that can be used as texel element. const char* texeElementScalarTypes[] = { @@ -539,12 +546,6 @@ const char* texeElementScalarTypes[] = { for (auto elementType : texeElementScalarTypes) { }}} -extension $(elementType) : ITexelElement -{ - typealias Element = $(elementType); - static const int elementCount = 1; - __intrinsic_op(0) __init(Element x); -} extension<int N> vector<$(elementType), N> : ITexelElement { typealias Element = $(elementType); diff --git a/source/slang/slang-check-decl.cpp b/source/slang/slang-check-decl.cpp index 76074f551..1ef5b1cec 100644 --- a/source/slang/slang-check-decl.cpp +++ b/source/slang/slang-check-decl.cpp @@ -9789,6 +9789,46 @@ void SemanticsVisitor::validateArraySizeForVariable(VarDeclBase* varDecl) } } +bool getExtensionTargetDeclList( + ASTBuilder* astBuilder, + DeclRefType* targetDeclRefType, + ExtensionDecl* extDecl, + ShortList<AggTypeDecl*>& targetDecls) +{ + if (auto aggTypeDeclRef = targetDeclRefType->getDeclRef().as<AggTypeDecl>()) + { + auto aggTypeDecl = aggTypeDeclRef.getDecl(); + + targetDecls.add(aggTypeDecl); + return true; + } + + auto genericParamDeclRef = targetDeclRefType->getDeclRef().as<GenericTypeParamDeclBase>(); + if (!genericParamDeclRef) + return false; + + auto genericParent = as<GenericDecl>(genericParamDeclRef.getParent().getDecl()); + if (!genericParent) + return false; + + if (genericParent != extDecl->parentDecl) + return false; + + for (auto member : getMembersOfType<GenericTypeConstraintDecl>(astBuilder, genericParent)) + { + if (getSub(astBuilder, member) == targetDeclRefType) + { + auto baseType = getSup(astBuilder, member); + if (auto baseTypeDecl = isDeclRefTypeOf<AggTypeDecl>(baseType)) + { + targetDecls.add(baseTypeDecl.getDecl()); + } + } + } + return targetDecls.getCount() != 0; +} + + void SemanticsDeclBasesVisitor::_validateExtensionDeclTargetType(ExtensionDecl* decl) { if (auto targetDeclRefType = as<DeclRefType>(decl->targetType)) @@ -11582,8 +11622,8 @@ void checkDerivativeAttributeImpl( auto derivativeFuncThisType = getTypeForThisExpr(visitor, calleeFuncDeclRef); // If the function is a member function, we need to check that the - // `this` type matches the expected type. This will ensure that after lowering to - // IR, the two functions are compatible. + // `this` type matches the expected type. This will ensure that after lowering + // to IR, the two functions are compatible. // if (!areTypesCompatibile(visitor, funcThisType, derivativeFuncThisType)) { @@ -11971,8 +12011,9 @@ void checkDerivativeOfAttributeImpl( if (as<ErrorType>(resolved->type.type)) { - // If we can't resolve a type, something went wrong. If we're working with a generic - // decl, the most likely cause is a failure of generic argument inference. + // If we can't resolve a type, something went wrong. If we're working with a + // generic decl, the most likely cause is a failure of generic argument + // inference. // visitor->getSink()->diagnose( derivativeOfAttr, @@ -12284,8 +12325,8 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers( // Find the base type's members first for (auto inheritanceMember : structDecl->getMembersOfType<InheritanceDecl>()) { - // For base types, we need to pick their parameters of the constructor to the derived type's - // constructor + // For base types, we need to pick their parameters of the constructor to the derived + // type's constructor if (auto baseTypeDeclRef = isDeclRefTypeOf<StructDecl>(inheritanceMember->base.type)) { // We should only find the member initialization constructor because it is the @@ -12294,15 +12335,15 @@ bool SemanticsDeclAttributesVisitor::collectInitializableMembers( baseTypeDeclRef.getDecl(), ConstructorDecl::ConstructorFlavor::SynthesizedMemberInit); - // The constructor has to have higher or equal visibility level than the struct itself, - // otherwise, it's not accessible so we will not pick up. + // The constructor has to have higher or equal visibility level than the struct + // itself, otherwise, it's not accessible so we will not pick up. if (ctor && getDeclVisibility(ctor) >= ctorVisibility) { for (ParamDecl* param : ctor->getParameters()) { - // Because the parameters in the ctor must have the higher or equal visibility - // than the ctor itself, we don't need to check the visibility level of the - // parameter. + // Because the parameters in the ctor must have the higher or equal + // visibility than the ctor itself, we don't need to check the visibility + // level of the parameter. resultMembers.add(param); } } @@ -12342,10 +12383,9 @@ static Expr* _getParamDefaultValue(SemanticsVisitor* visitor, VarDeclBase* varDe bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* structDecl) { - // If a type or its base type already defines any explicit constructors, do not synthesize any - // constructors. - // See - // https://github.com/shader-slang/spec/blob/main/proposals/004-initialization.md#inheritance-initialization + // If a type or its base type already defines any explicit constructors, do not synthesize + // any constructors. see: + // https://github.com/shader-slang/slang/blob/master/docs/proposals/004-initialization.md#inheritance-initialization if (_hasExplicitConstructor(structDecl, true)) return false; @@ -12397,9 +12437,9 @@ bool SemanticsDeclAttributesVisitor::_synthesizeCtorSignature(StructDecl* struct ctorParam->loc = ctor->loc; ctor->members.add(ctorParam); - // We need to ensure member is `no_diff` if it cannot be differentiated, `ctor` modifiers do - // not matter in this case since member-wise ctor is always differentiable or "treat as - // differentiable". + // We need to ensure member is `no_diff` if it cannot be differentiated, `ctor` + // modifiers do not matter in this case since member-wise ctor is always differentiable + // or "treat as differentiable". if (!isTypeDifferentiable(member->getType()) || member->hasModifier<NoDiffModifier>()) { auto noDiffMod = m_astBuilder->create<NoDiffModifier>(); @@ -12559,7 +12599,8 @@ void SemanticsDeclAttributesVisitor::visitStructDecl(StructDecl* structDecl) totalWidth += int(thisFieldWidth); groupInfo.add({memberIndex, int(thisFieldWidth), t, bfm}); } - // If the struct ended with a bitpacked member, then make sure we don't forget the last group + // If the struct ended with a bitpacked member, then make sure we don't forget the last + // group dispatchSomeBitPackedMembers(); } @@ -12630,8 +12671,8 @@ static void _propagateRequirement( if (!isAnyInvalid && resultCaps.isInvalid()) { // If joining the referenced decl's requirements results an invalid capability set, - // then the decl is using things that require conflicting set of capabilities, and we should - // diagnose an error. + // then the decl is using things that require conflicting set of capabilities, and we + // should diagnose an error. if (referencedDecl && decl) { maybeDiagnose( @@ -12736,17 +12777,17 @@ struct CapabilityDeclReferenceVisitor // `calling_functions_targets`: // ``` default_target = calling_functions_targets-{other_case_targets} ``` // - // * `calling_functions_capability` = `requirement attribute` of the calling function; - // if missing + // * `calling_functions_capability` = `requirement attribute` of the calling + // function; if missing // we can assume it is `any_target` // - // * `{other_case_targets}` = set of all capabilities all `case` statments target inside - // the `__target_switch` + // * `{other_case_targets}` = set of all capabilities all `case` statments target + // inside the `__target_switch` - // If we do not handle `default:`, the codegen will fail when trying to find a specific - // codegen target not handled explicitly by a `case` statment. - // We must also ensure the `default` case is last so we have priority to hit `case` - // statments and can preprocess `case` statments before the `default` case. + // If we do not handle `default:`, the codegen will fail when trying to find a + // specific codegen target not handled explicitly by a `case` statment. We must also + // ensure the `default` case is last so we have priority to hit `case` statments and + // can preprocess `case` statments before the `default` case. CapabilitySet targetCap; if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) == CapabilityName::Invalid) @@ -12901,7 +12942,8 @@ CapabilitySet SemanticsDeclCapabilityVisitor::getDeclaredCapabilitySet(Decl* dec // For every existing target, we want to join their requirements together. // If the the parent defines additional targets, we want to add them to the disjunction set. // For example: - // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void foo(); } + // [require(glsl)] struct Parent { [require(glsl, glsl_ext_1)] [require(spirv)] void + // foo(); } // The requirement for `foo` should be glsl+glsl_ext_1 | spirv. // CapabilitySet declaredCaps; @@ -12990,8 +13032,8 @@ static inline void _dispatchCapabilitiesVisitorOfFunctionDecl( void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* funcDecl) { - // If the function is an entrypoint and specifies a target stage, add the capabilities to our - // function capabilities. + // If the function is an entrypoint and specifies a target stage, add the capabilities to + // our function capabilities. _dispatchCapabilitiesVisitorOfFunctionDecl( this, funcDecl, @@ -13012,8 +13054,8 @@ void SemanticsDeclCapabilityVisitor::visitFunctionDeclBase(FunctionDeclBase* fun auto vis = getDeclVisibility(funcDecl); - // If 0 capabilities were annotated on a function, capabilities are inferred from the function - // body + // If 0 capabilities were annotated on a function, capabilities are inferred from the + // function body if (declaredCaps.isEmpty()) { declaredCaps = funcDecl->inferredCapabilityRequirements; @@ -13133,7 +13175,8 @@ DeclVisibility getDeclVisibility(Decl* decl) : parentModule->defaultVisibility; } - // Members of other agg type decls will have their default visibility capped to the parents'. + // Members of other agg type decls will have their default visibility capped to the + // parents'. if (as<NamespaceDecl>(decl)) { return DeclVisibility::Public; @@ -13345,15 +13388,16 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability( // There are two causes for why type checking failed on failedAvailableSet. // The first scenario is that failedAvailableSet defines a set of capabilities on a - // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we have - // a function: + // compilation target (e.g. hlsl) that isn't defined by some callees, for example, if we + // have a function: // [require(hlsl)] // <-- failedAvailableSet // [require(cpp)] // void caller() // { // printf(); // assume this is defined for (cpp | cuda). // } - // In this case we should diagnose error reporting printf isn't defined on a required target. + // In this case we should diagnose error reporting printf isn't defined on a required + // target. // // Now, we detect if we are case 1. @@ -13370,8 +13414,8 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability( decl, outFailedAtom); - // Anything defined on a non-failed target atom may be the culprit to why we fail having - // a target capability. Print out all possible culprits. + // Anything defined on a non-failed target atom may be the culprit to why we fail + // having a target capability. Print out all possible culprits. CapabilityAtomSet failedAtomSet; failedAtomSet.add((UInt)outFailedAtom); CapabilityAtomSet targetsNotUsedSet; @@ -13395,9 +13439,12 @@ void SemanticsDeclCapabilityVisitor::diagnoseUndeclaredCapability( } } - //// The second scenario is when the callee is using a capability that is not provided by the - /// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { / useD(); - ///// require capability (hlsl,d) / } / In this case we should report that useD() is using a + //// The second scenario is when the callee is using a capability that is not provided by + /// the + /// requirement. / For example: / [require(hlsl,b,c)] / void caller() / { / + /// useD(); + ///// require capability (hlsl,d) / } / In this case we should report that useD() is + /// using a /// capability that is not declared by caller. //// diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 350362a0e..59290f8ad 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -3066,4 +3066,17 @@ bool resolveStageOfProfileWithEntryPoint( const List<RefPtr<TargetRequest>>& targets, FuncDecl* entryPointFuncDecl, DiagnosticSink* sink); + +// For an extensions decl, collect a list of decls on which the extension might be applying to. +// For example, if we see a `extension Foo`, return a `Decl*` that represents `struct Foo`. +// In the case of free-form generic extensions i.e. `extension<T:IFoo> T : IBar`, return `IFoo`. +// These are the decls that we need to register the extension with in +// `mapTypeToCandidateExtensions`. +// Returns true when any base decls are found. +bool getExtensionTargetDeclList( + ASTBuilder* astBuilder, + DeclRefType* targetDeclRefType, + ExtensionDecl* extDeclRef, + ShortList<AggTypeDecl*>& targetDecls); + } // namespace Slang diff --git a/source/slang/slang-ir-hlsl-legalize.cpp b/source/slang/slang-ir-hlsl-legalize.cpp index ec2419985..7116d635e 100644 --- a/source/slang/slang-ir-hlsl-legalize.cpp +++ b/source/slang/slang-ir-hlsl-legalize.cpp @@ -36,7 +36,7 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in continue; auto forceStructArg = arg->getOperand(0); auto forceStructBaseType = - as<IRType>(forceStructArg->getDataType()->getOperand(0)); + (IRType*)(forceStructArg->getDataType()->getOperand(0)); IRBuilder builder(call); if (forceStructBaseType->getOp() == kIROp_StructType) { diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 9c3892c0e..dbefa68c7 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -1143,7 +1143,7 @@ struct IRMixedDifferentialInstDecoration : IRAutodiffInstDecoration IRUse pairType; IR_LEAF_ISA(MixedDifferentialInstDecoration) - IRType* getPairType() { return as<IRType>(getOperand(0)); } + IRType* getPairType() { return (IRType*)(getOperand(0)); } }; struct IRRecomputeBlockDecoration : IRAutodiffInstDecoration diff --git a/source/slang/slang-ir-link.cpp b/source/slang/slang-ir-link.cpp index 364e58c48..53bbbda9e 100644 --- a/source/slang/slang-ir-link.cpp +++ b/source/slang/slang-ir-link.cpp @@ -614,8 +614,8 @@ IRWitnessTable* cloneWitnessTableImpl( IRWitnessTable* clonedTable = dstTable; if (!clonedTable) { - auto clonedBaseType = cloneType(context, as<IRType>(originalTable->getConformanceType())); - auto clonedSubType = cloneType(context, as<IRType>(originalTable->getConcreteType())); + auto clonedBaseType = cloneType(context, (IRType*)(originalTable->getConformanceType())); + auto clonedSubType = cloneType(context, (IRType*)(originalTable->getConcreteType())); clonedTable = builder->createWitnessTable(clonedBaseType, clonedSubType); } cloneSimpleGlobalValueImpl(context, originalTable, originalValues, clonedTable, registerValue); diff --git a/source/slang/slang-ir-use-uninitialized-values.cpp b/source/slang/slang-ir-use-uninitialized-values.cpp index 51de2117c..71aae5923 100644 --- a/source/slang/slang-ir-use-uninitialized-values.cpp +++ b/source/slang/slang-ir-use-uninitialized-values.cpp @@ -186,7 +186,7 @@ static bool canIgnoreType(IRType* type, IRType* upper) if (auto spec = as<IRSpecialize>(type)) { IRInst* inner = getResolvedInstForDecorations(spec); - IRType* innerType = as<IRType>(inner); + IRType* innerType = (IRType*)(inner); return canIgnoreType(innerType, upper); } diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 06cdf430b..ed8a52b9e 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -9453,10 +9453,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> // should handle propgation of value-size information from constraints // back to generic parameters? // - if (auto declRefType = as<DeclRefType>(constraintDecl->sub.type)) + if (auto genParamDeclRef = + isDeclRefTypeOf<GenericTypeParamDeclBase>(constraintDecl->sub.type)) { - auto typeParamDeclVal = - subContext->findLoweredDecl(declRefType->getDeclRef().getDecl()); + auto typeParamDeclVal = subContext->findLoweredDecl(genParamDeclRef.getDecl()); SLANG_ASSERT(typeParamDeclVal && typeParamDeclVal->val); subBuilder->addTypeConstraintDecoration(typeParamDeclVal->val, supType); } diff --git a/source/slang/slang-serialize-container.cpp b/source/slang/slang-serialize-container.cpp index e5775f526..b5121d373 100644 --- a/source/slang/slang-serialize-container.cpp +++ b/source/slang/slang-serialize-container.cpp @@ -5,6 +5,7 @@ #include "../core/slang-math.h" #include "../core/slang-stream.h" #include "../core/slang-text-io.h" +#include "slang-check-impl.h" #include "slang-compiler.h" #include "slang-mangled-lexer.h" #include "slang-parser.h" @@ -813,15 +814,16 @@ static List<ExtensionDecl*>& _getCandidateExtensionList( if (auto targetDeclRefType = as<DeclRefType>(extensionDecl->targetType)) { - // Attach our extension to that type as a candidate... - if (auto aggTypeDeclRef = - targetDeclRefType->getDeclRef() - .as<AggTypeDecl>()) + ShortList<AggTypeDecl*> baseDecls; + getExtensionTargetDeclList( + astBuilder, + targetDeclRefType, + extensionDecl, + baseDecls); + for (auto baseDecl : baseDecls) { - auto aggTypeDecl = aggTypeDeclRef.getDecl(); - _getCandidateExtensionList( - aggTypeDecl, + baseDecl, moduleDecl->mapTypeToCandidateExtensions) .add(extensionDecl); } |
