diff options
Diffstat (limited to 'source/slang/parameter-binding.cpp')
| -rw-r--r-- | source/slang/parameter-binding.cpp | 354 |
1 files changed, 341 insertions, 13 deletions
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index e1c5c1aca..37642ac81 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -351,14 +351,347 @@ LayoutSemanticInfo ExtractLayoutSemanticInfo( return info; } +static Name* getReflectionName(VarDeclBase* varDecl) +{ + if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>()) + return reflectionNameModifier->nameAndLoc.name; + + return varDecl->getName(); +} + +// Information tracked when doing a structural +// match of types. +struct StructuralTypeMatchStack +{ + DeclRef<VarDeclBase> leftDecl; + DeclRef<VarDeclBase> rightDecl; + StructuralTypeMatchStack* parent; +}; + +static void diagnoseParameterTypeMismatch( + ParameterBindingContext* context, + StructuralTypeMatchStack* inStack) +{ + assert(inStack); + + // The bottom-most entry in the stack should represent + // the shader parameters that kicked things off + auto stack = inStack; + while(stack->parent) + stack = stack->parent; + + getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl)); + getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeMismatch( + ParameterBindingContext* context, + StructuralTypeMatchStack* inStack) +{ + auto stack = inStack; + assert(stack); + diagnoseParameterTypeMismatch(context, stack); + + auto leftType = GetType(stack->leftDecl); + auto rightType = GetType(stack->rightDecl); + + if( stack->parent ) + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType); + getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); + + stack = stack->parent; + if( stack ) + { + while( stack->parent ) + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } + } + else + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType); + } +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeFieldsMismatch( + ParameterBindingContext* context, + DeclRef<Decl> const& left, + DeclRef<Decl> const& right, + StructuralTypeMatchStack* stack) +{ + diagnoseParameterTypeMismatch(context, stack); + + getSink(context)->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName()); + getSink(context)->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName()); + + if( stack ) + { + while( stack->parent ) + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } +} + +static void collectFields( + DeclRef<AggTypeDecl> declRef, + List<DeclRef<StructField>>& outFields) +{ + for( auto fieldDeclRef : getMembersOfType<StructField>(declRef) ) + { + if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + outFields.Add(fieldDeclRef); + } +} + +static bool validateTypesMatch( + ParameterBindingContext* context, + Type* left, + Type* right, + StructuralTypeMatchStack* stack); + +static bool validateIntValuesMatch( + ParameterBindingContext* context, + IntVal* left, + IntVal* right, + StructuralTypeMatchStack* stack) +{ + if(left->EqualsVal(right)) + return true; + + // TODO: are there other cases we need to handle here? + + diagnoseTypeMismatch(context, stack); + return false; +} + + +static bool validateValuesMatch( + ParameterBindingContext* context, + Val* left, + Val* right, + StructuralTypeMatchStack* stack) +{ + if( auto leftType = dynamic_cast<Type*>(left) ) + { + if( auto rightType = dynamic_cast<Type*>(right) ) + { + return validateTypesMatch(context, leftType, rightType, stack); + } + } + + if( auto leftInt = dynamic_cast<IntVal*>(left) ) + { + if( auto rightInt = dynamic_cast<IntVal*>(right) ) + { + return validateIntValuesMatch(context, leftInt, rightInt, stack); + } + } + + if( auto leftWitness = dynamic_cast<SubtypeWitness*>(left) ) + { + if( auto rightWitness = dynamic_cast<SubtypeWitness*>(right) ) + { + return true; + } + } + + diagnoseTypeMismatch(context, stack); + return false; +} + +static bool validateGenericSubstitutionsMatch( + ParameterBindingContext* context, + GenericSubstitution* left, + GenericSubstitution* right, + StructuralTypeMatchStack* stack) +{ + if( !left ) + { + if( !right ) + { + return true; + } + + diagnoseTypeMismatch(context, stack); + return false; + } + + + + UInt argCount = left->args.Count(); + if( argCount != right->args.Count() ) + { + diagnoseTypeMismatch(context, stack); + return false; + } + + for( UInt aa = 0; aa < argCount; ++aa ) + { + auto leftArg = left->args[aa]; + auto rightArg = right->args[aa]; + + if(!validateValuesMatch(context, leftArg, rightArg, stack)) + return false; + } + + return true; +} + +static bool validateSpecializationsMatch( + ParameterBindingContext* context, + SubstitutionSet left, + SubstitutionSet right, + StructuralTypeMatchStack* stack) +{ + if(!validateGenericSubstitutionsMatch( + context, + left.genericSubstitutions, + right.genericSubstitutions, + stack)) + { + return false; + } + + // TODO: anything else to match? + + return true; +} + +// Determine if two types "match" for the purposes of `cbuffer` layout rules. +// +static bool validateTypesMatch( + ParameterBindingContext* context, + Type* left, + Type* right, + StructuralTypeMatchStack* stack) +{ + if(left->Equals(right)) + return true; + + // It is possible that the types don't match exactly, but + // they *do* match structurally. + + // Note: the following code will lead to infinite recursion if there + // are ever recursive types. We'd need a more refined system to + // cache the matches we've already found. + + if( auto leftDeclRefType = left->As<DeclRefType>() ) + { + if( auto rightDeclRefType = right->As<DeclRefType>() ) + { + // Are they references to matching decl refs? + auto leftDeclRef = leftDeclRefType->declRef; + auto rightDeclRef = rightDeclRefType->declRef; + + // Do the reference the same declaration? Or declarations + // with the same name? + // + // TODO: we should only consider the same-name case if the + // declarations come from translation units being compiled + // (and not an imported module). + if( leftDeclRef.getDecl() == rightDeclRef.getDecl() + || leftDeclRef.GetName() == rightDeclRef.GetName() ) + { + // Check that any generic arguments match + if( !validateSpecializationsMatch( + context, + leftDeclRef.substitutions, + rightDeclRef.substitutions, + stack) ) + { + return false; + } + + // Check that any declared fields match too. + if( auto leftStructDeclRef = leftDeclRef.As<AggTypeDecl>() ) + { + if( auto rightStructDeclRef = rightDeclRef.As<AggTypeDecl>() ) + { + List<DeclRef<StructField>> leftFields; + List<DeclRef<StructField>> rightFields; + + collectFields(leftStructDeclRef, leftFields); + collectFields(rightStructDeclRef, rightFields); + + UInt leftFieldCount = leftFields.Count(); + UInt rightFieldCount = rightFields.Count(); + + if( leftFieldCount != rightFieldCount ) + { + diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack); + return false; + } + + for( UInt ii = 0; ii < leftFieldCount; ++ii ) + { + auto leftField = leftFields[ii]; + auto rightField = rightFields[ii]; + + if( leftField.GetName() != rightField.GetName() ) + { + diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack); + return false; + } + + auto leftFieldType = GetType(leftField); + auto rightFieldType = GetType(rightField); + + StructuralTypeMatchStack subStack; + subStack.parent = stack; + subStack.leftDecl = leftField; + subStack.rightDecl = rightField; + + if(!validateTypesMatch(context, leftFieldType,rightFieldType, &subStack)) + return false; + } + } + } + + // Everything seemed to match recursively. + return true; + } + } + } + + // If we are looking at `T[N]` and `U[M]` we want to check that + // `T` is structurally equivalent to `U` and `N` is the same as `M`. + else if( auto leftArrayType = left->As<ArrayExpressionType>() ) + { + if( auto rightArrayType = right->As<ArrayExpressionType>() ) + { + if(!validateTypesMatch(context, leftArrayType->baseType, rightArrayType->baseType, stack) ) + return false; + + if(!validateValuesMatch(context, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack)) + return false; + + return true; + } + } + + diagnoseTypeMismatch(context, stack); + return false; +} + // This function is supposed to determine if two global shader // parameter declarations represent the same logical parameter // (so that they should get the exact same binding(s) allocated). // static bool doesParameterMatch( - ParameterBindingContext*, + ParameterBindingContext* context, RefPtr<VarLayout> varLayout, - ParameterInfo*) + ParameterInfo* parameterInfo) { // Any "varying" parameter should automatically be excluded // @@ -378,9 +711,12 @@ static bool doesParameterMatch( } } - // TODO: this is where we should apply a more detailed - // matching process, to check that the existing - // declarations conform to the same basic layout. + StructuralTypeMatchStack stack; + stack.parent = nullptr; + stack.leftDecl = varLayout->varDecl; + stack.rightDecl = parameterInfo->varLayouts[0]->varDecl; + + validateTypesMatch(context, varLayout->typeLayout->type, parameterInfo->varLayouts[0]->typeLayout->type, &stack); return true; } @@ -415,14 +751,6 @@ static bool findLayoutArg( // -static Name* getReflectionName(VarDeclBase* varDecl) -{ - if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>()) - return reflectionNameModifier->nameAndLoc.name; - - return varDecl->getName(); -} - static bool isGLSLBuiltinName(VarDeclBase* varDecl) { return getText(getReflectionName(varDecl)).StartsWith("gl_"); |
