diff options
Diffstat (limited to 'source/slang/check.cpp')
| -rw-r--r-- | source/slang/check.cpp | 515 |
1 files changed, 504 insertions, 11 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index a9f84c5c3..998324612 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -9306,7 +9306,7 @@ namespace Slang } /// Recursively walk `paramDeclRef` and add any required existential slots to `ioSlots`. - static void _collectExistentialParamsRec( + static void _collectExistentialSlotsRec( ExistentialSlots& ioSlots, DeclRef<VarDeclBase> paramDeclRef) { @@ -9339,7 +9339,7 @@ namespace Slang if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) continue; - _collectExistentialParamsRec(ioSlots, fieldDeclRef); + _collectExistentialSlotsRec(ioSlots, fieldDeclRef); } } } @@ -9349,11 +9349,26 @@ namespace Slang // element types. } + /// Add information about a shader parameter to `ioParams` and `ioSlots` + static void _collectExistentialSlotsForShaderParam( + ShaderParamInfo& ioParamInfo, + ExistentialSlots& ioSlots, + DeclRef<VarDeclBase> paramDeclRef) + { + UInt startSlot = ioSlots.types.Count(); + _collectExistentialSlotsRec(ioSlots, paramDeclRef); + UInt endSlot = ioSlots.types.Count(); + UInt slotCount = endSlot - startSlot; + + ioParamInfo.firstExistentialTypeSlot = startSlot; + ioParamInfo.existentialTypeSlotCount = slotCount; + } + /// Enumerate the existential-type parameters of an `EntryPoint`. /// /// Any parameters found will be added to the list of existential slots on `this`. /// - void EntryPoint::_collectExistentialParams() + void EntryPoint::_collectShaderParams() { // Note: we defensively test whether there is a function decl-ref // because this routine gets called from the constructor, and @@ -9363,7 +9378,15 @@ namespace Slang { for( auto paramDeclRef : GetParameters(funcDeclRef) ) { - _collectExistentialParamsRec(m_existentialSlots, paramDeclRef); + ShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = paramDeclRef; + + _collectExistentialSlotsForShaderParam( + shaderParamInfo, + m_existentialSlots, + paramDeclRef); + + m_shaderParams.Add(shaderParamInfo); } } } @@ -9644,16 +9667,433 @@ namespace Slang return entryPoint; } + /// Get the name a variable will use for reflection purposes +Name* getReflectionName(VarDeclBase* varDecl) +{ + if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>()) + return reflectionNameModifier->nameAndLoc.name; + + return varDecl->getName(); +} + +// Information tracked when doing a structural +// match of types. +struct StructuralTypeMatchStack +{ + DeclRef<VarDeclBase> leftDecl; + DeclRef<VarDeclBase> rightDecl; + StructuralTypeMatchStack* parent; +}; + +static void diagnoseParameterTypeMismatch( + DiagnosticSink* sink, + StructuralTypeMatchStack* inStack) +{ + SLANG_ASSERT(inStack); + + // The bottom-most entry in the stack should represent + // the shader parameters that kicked things off + auto stack = inStack; + while(stack->parent) + stack = stack->parent; + + sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl)); + sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeMismatch( + DiagnosticSink* sink, + StructuralTypeMatchStack* inStack) +{ + auto stack = inStack; + SLANG_ASSERT(stack); + diagnoseParameterTypeMismatch(sink, stack); + + auto leftType = GetType(stack->leftDecl); + auto rightType = GetType(stack->rightDecl); + + if( stack->parent ) + { + sink->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType); + sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); + + stack = stack->parent; + if( stack ) + { + while( stack->parent ) + { + sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } + } + else + { + sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType); + } +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeFieldsMismatch( + DiagnosticSink* sink, + DeclRef<Decl> const& left, + DeclRef<Decl> const& right, + StructuralTypeMatchStack* stack) +{ + diagnoseParameterTypeMismatch(sink, stack); + + sink->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName()); + sink->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName()); + + if( stack ) + { + while( stack->parent ) + { + sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } +} + +static void collectFields( + DeclRef<AggTypeDecl> declRef, + List<DeclRef<VarDecl>>& outFields) +{ + for( auto fieldDeclRef : getMembersOfType<VarDecl>(declRef) ) + { + if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + outFields.Add(fieldDeclRef); + } +} + +static bool validateTypesMatch( + DiagnosticSink* sink, + Type* left, + Type* right, + StructuralTypeMatchStack* stack); + +static bool validateIntValuesMatch( + DiagnosticSink* sink, + IntVal* left, + IntVal* right, + StructuralTypeMatchStack* stack) +{ + if(left->EqualsVal(right)) + return true; + + // TODO: are there other cases we need to handle here? + + diagnoseTypeMismatch(sink, stack); + return false; +} + + +static bool validateValuesMatch( + DiagnosticSink* sink, + Val* left, + Val* right, + StructuralTypeMatchStack* stack) +{ + if( auto leftType = dynamicCast<Type>(left) ) + { + if( auto rightType = dynamicCast<Type>(right) ) + { + return validateTypesMatch(sink, leftType, rightType, stack); + } + } + + if( auto leftInt = dynamicCast<IntVal>(left) ) + { + if( auto rightInt = dynamicCast<IntVal>(right) ) + { + return validateIntValuesMatch(sink, leftInt, rightInt, stack); + } + } + + if( auto leftWitness = dynamicCast<SubtypeWitness>(left) ) + { + if( auto rightWitness = dynamicCast<SubtypeWitness>(right) ) + { + return true; + } + } + + diagnoseTypeMismatch(sink, stack); + return false; +} + +static bool validateGenericSubstitutionsMatch( + DiagnosticSink* sink, + GenericSubstitution* left, + GenericSubstitution* right, + StructuralTypeMatchStack* stack) +{ + if( !left ) + { + if( !right ) + { + return true; + } + + diagnoseTypeMismatch(sink, stack); + return false; + } + + + + UInt argCount = left->args.Count(); + if( argCount != right->args.Count() ) + { + diagnoseTypeMismatch(sink, stack); + return false; + } + + for( UInt aa = 0; aa < argCount; ++aa ) + { + auto leftArg = left->args[aa]; + auto rightArg = right->args[aa]; + + if(!validateValuesMatch(sink, leftArg, rightArg, stack)) + return false; + } + + return true; +} + +static bool validateThisTypeSubstitutionsMatch( + DiagnosticSink* /*sink*/, + ThisTypeSubstitution* /*left*/, + ThisTypeSubstitution* /*right*/, + StructuralTypeMatchStack* /*stack*/) +{ + // TODO: actual checking. + return true; +} + +static bool validateSpecializationsMatch( + DiagnosticSink* sink, + SubstitutionSet left, + SubstitutionSet right, + StructuralTypeMatchStack* stack) +{ + auto ll = left.substitutions; + auto rr = right.substitutions; + for(;;) + { + // Skip any global generic substitutions. + if(auto leftGlobalGeneric = as<GlobalGenericParamSubstitution>(ll)) + { + ll = leftGlobalGeneric->outer; + continue; + } + if(auto rightGlobalGeneric = as<GlobalGenericParamSubstitution>(rr)) + { + rr = rightGlobalGeneric->outer; + continue; + } + + // If either ran out, then we expect both to have run out. + if(!ll || !rr) + return !ll && !rr; + + auto leftSubst = ll; + auto rightSubst = rr; + + ll = ll->outer; + rr = rr->outer; + + if(auto leftGeneric = as<GenericSubstitution>(leftSubst)) + { + if(auto rightGeneric = as<GenericSubstitution>(rightSubst)) + { + if(validateGenericSubstitutionsMatch(sink, leftGeneric, rightGeneric, stack)) + { + continue; + } + } + } + else if(auto leftThisType = as<ThisTypeSubstitution>(leftSubst)) + { + if(auto rightThisType = as<ThisTypeSubstitution>(rightSubst)) + { + if(validateThisTypeSubstitutionsMatch(sink, leftThisType, rightThisType, stack)) + { + continue; + } + } + } + + return false; + } + + return true; +} + +// Determine if two types "match" for the purposes of `cbuffer` layout rules. +// +static bool validateTypesMatch( + DiagnosticSink* sink, + Type* left, + Type* right, + StructuralTypeMatchStack* stack) +{ + if(left->Equals(right)) + return true; + + // It is possible that the types don't match exactly, but + // they *do* match structurally. + + // Note: the following code will lead to infinite recursion if there + // are ever recursive types. We'd need a more refined system to + // cache the matches we've already found. + + if( auto leftDeclRefType = as<DeclRefType>(left) ) + { + if( auto rightDeclRefType = as<DeclRefType>(right) ) + { + // Are they references to matching decl refs? + auto leftDeclRef = leftDeclRefType->declRef; + auto rightDeclRef = rightDeclRefType->declRef; + + // Do the reference the same declaration? Or declarations + // with the same name? + // + // TODO: we should only consider the same-name case if the + // declarations come from translation units being compiled + // (and not an imported module). + if( leftDeclRef.getDecl() == rightDeclRef.getDecl() + || leftDeclRef.GetName() == rightDeclRef.GetName() ) + { + // Check that any generic arguments match + if( !validateSpecializationsMatch( + sink, + leftDeclRef.substitutions, + rightDeclRef.substitutions, + stack) ) + { + return false; + } + + // Check that any declared fields match too. + if( auto leftStructDeclRef = leftDeclRef.as<AggTypeDecl>() ) + { + if( auto rightStructDeclRef = rightDeclRef.as<AggTypeDecl>() ) + { + List<DeclRef<VarDecl>> leftFields; + List<DeclRef<VarDecl>> rightFields; + + collectFields(leftStructDeclRef, leftFields); + collectFields(rightStructDeclRef, rightFields); + + UInt leftFieldCount = leftFields.Count(); + UInt rightFieldCount = rightFields.Count(); + + if( leftFieldCount != rightFieldCount ) + { + diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack); + return false; + } + + for( UInt ii = 0; ii < leftFieldCount; ++ii ) + { + auto leftField = leftFields[ii]; + auto rightField = rightFields[ii]; + + if( leftField.GetName() != rightField.GetName() ) + { + diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack); + return false; + } + + auto leftFieldType = GetType(leftField); + auto rightFieldType = GetType(rightField); + + StructuralTypeMatchStack subStack; + subStack.parent = stack; + subStack.leftDecl = leftField; + subStack.rightDecl = rightField; + + if(!validateTypesMatch(sink, leftFieldType,rightFieldType, &subStack)) + return false; + } + } + } + + // Everything seemed to match recursively. + return true; + } + } + } + + // If we are looking at `T[N]` and `U[M]` we want to check that + // `T` is structurally equivalent to `U` and `N` is the same as `M`. + else if( auto leftArrayType = as<ArrayExpressionType>(left) ) + { + if( auto rightArrayType = as<ArrayExpressionType>(right) ) + { + if(!validateTypesMatch(sink, leftArrayType->baseType, rightArrayType->baseType, stack) ) + return false; + + if(!validateValuesMatch(sink, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack)) + return false; + + return true; + } + } + + diagnoseTypeMismatch(sink, stack); + return false; +} + +// This function is supposed to determine if two global shader +// parameter declarations represent the same logical parameter +// (so that they should get the exact same binding(s) allocated). +// +static bool doesParameterMatch( + DiagnosticSink* sink, + DeclRef<VarDeclBase> varDeclRef, + DeclRef<VarDeclBase> existingVarDeclRef) +{ + StructuralTypeMatchStack stack; + stack.parent = nullptr; + stack.leftDecl = varDeclRef; + stack.rightDecl = existingVarDeclRef; + + validateTypesMatch(sink, GetType(varDeclRef), GetType(existingVarDeclRef), &stack); + + return true; +} + + + + /// Enumerate the existential-type parameters of a `Program`. /// /// Any parameters found will be added to the list of existential slots on `this`. /// - void Program::_collectExistentialParams() + void Program::_collectShaderParams(DiagnosticSink* sink) { - // We need to inspect all of the global shader parameters + // We need to collect all of the global shader parameters // referenced by the compile request, and for each we - // need to determine what existential types parameters it implies. + // need to do a few things: + // + // * We need to determine if the parameter is a duplicate/redeclaration + // of the "same" parameter in another translation unit, and collapse + // those into one logical shader parameter if so. + // + // * We need to determine what existential type slots are introduced + // by the parameter, and associate that information with the parameter. + // + // To deal with the first issue, we will maintain a map from a parameter + // name to the index of an existing parameter with that name. // + Dictionary<Name*, Int> mapNameToParamIndex; + for( auto module : getModuleDependencies() ) { auto moduleDecl = module->getModuleDecl(); @@ -9662,9 +10102,55 @@ namespace Slang if(!isGlobalShaderParameter(globalVar)) continue; - _collectExistentialParamsRec( + // This declaration may represent the same logical parameter + // as a declaration that came from a different translation unit. + // If that is the case, we want to re-use the same `ShaderParamInfo` + // across both parameters. + // + // TODO: This logic currently detects *any* global-scope parameters + // with matching names, but it should eventually be narrowly + // scoped so that it only applies to parameters from unnamed modules + // (that is, modules that represent directly-compiled shader files + // and not `import`ed code). + // + // First we look for an existing entry matching the name + // of this parameter: + // + auto paramName = getReflectionName(globalVar); + Int existingParamIndex = -1; + if( mapNameToParamIndex.TryGetValue(paramName, existingParamIndex) ) + { + // If the parameters have the same name, but don't "match" according to some reasonable rules, + // then we will treat them as distinct global parameters. + // + // Note: all of the mismatch cases currently report errors, so that + // compilation will fail on a mismatch. + // + auto& existingParam = m_shaderParams[existingParamIndex]; + if( doesParameterMatch(sink, makeDeclRef(globalVar.Ptr()), existingParam.paramDeclRef) ) + { + // If we hit this case, then we had a match, and we should + // consider the new variable to be a redclaration of + // the existing one. + + existingParam.additionalParamDeclRefs.Add( + makeDeclRef(globalVar.Ptr())); + continue; + } + } + + Int newParamIndex = Int(m_shaderParams.Count()); + mapNameToParamIndex.Add(paramName, newParamIndex); + + GlobalShaderParamInfo shaderParamInfo; + shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr()); + + _collectExistentialSlotsForShaderParam( + shaderParamInfo, m_globalExistentialSlots, makeDeclRef(globalVar.Ptr())); + + m_shaderParams.Add(shaderParamInfo); } } } @@ -9795,7 +10281,7 @@ namespace Slang } } - program->_collectExistentialParams(); + program->_collectShaderParams(sink); return program; } @@ -10181,8 +10667,15 @@ namespace Slang specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts); - // Now deal with the existential arguments - specializedProgram->_collectExistentialParams(); + // Now deal with the shader parameters and existential arguments + // + // Note: We should in theory be able to just copy over the shader + // parameters and existential slot information from the unspecialized + // program. This could save some time, but it would also mean that + // the only way to create a specialized program is by creating an + // unspecialized on first, which is maybe not always desirable. + // + specializedProgram->_collectShaderParams(sink); specializedProgram->_specializeExistentialSlots(globalExistentialArgs, sink); return specializedProgram; |
