summaryrefslogtreecommitdiffstats
path: root/source/slang/parameter-binding.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/parameter-binding.cpp')
-rw-r--r--source/slang/parameter-binding.cpp354
1 files changed, 341 insertions, 13 deletions
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index e1c5c1aca..37642ac81 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -351,14 +351,347 @@ LayoutSemanticInfo ExtractLayoutSemanticInfo(
return info;
}
+static Name* getReflectionName(VarDeclBase* varDecl)
+{
+ if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>())
+ return reflectionNameModifier->nameAndLoc.name;
+
+ return varDecl->getName();
+}
+
+// Information tracked when doing a structural
+// match of types.
+struct StructuralTypeMatchStack
+{
+ DeclRef<VarDeclBase> leftDecl;
+ DeclRef<VarDeclBase> rightDecl;
+ StructuralTypeMatchStack* parent;
+};
+
+static void diagnoseParameterTypeMismatch(
+ ParameterBindingContext* context,
+ StructuralTypeMatchStack* inStack)
+{
+ assert(inStack);
+
+ // The bottom-most entry in the stack should represent
+ // the shader parameters that kicked things off
+ auto stack = inStack;
+ while(stack->parent)
+ stack = stack->parent;
+
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl));
+ getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl));
+}
+
+// Two types that were expected to match did not.
+// Inform the user with a suitable message.
+static void diagnoseTypeMismatch(
+ ParameterBindingContext* context,
+ StructuralTypeMatchStack* inStack)
+{
+ auto stack = inStack;
+ assert(stack);
+ diagnoseParameterTypeMismatch(context, stack);
+
+ auto leftType = GetType(stack->leftDecl);
+ auto rightType = GetType(stack->rightDecl);
+
+ if( stack->parent )
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType);
+ getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl));
+
+ stack = stack->parent;
+ if( stack )
+ {
+ while( stack->parent )
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl));
+ stack = stack->parent;
+ }
+ }
+ }
+ else
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType);
+ }
+}
+
+// Two types that were expected to match did not.
+// Inform the user with a suitable message.
+static void diagnoseTypeFieldsMismatch(
+ ParameterBindingContext* context,
+ DeclRef<Decl> const& left,
+ DeclRef<Decl> const& right,
+ StructuralTypeMatchStack* stack)
+{
+ diagnoseParameterTypeMismatch(context, stack);
+
+ getSink(context)->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName());
+ getSink(context)->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName());
+
+ if( stack )
+ {
+ while( stack->parent )
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl));
+ stack = stack->parent;
+ }
+ }
+}
+
+static void collectFields(
+ DeclRef<AggTypeDecl> declRef,
+ List<DeclRef<StructField>>& outFields)
+{
+ for( auto fieldDeclRef : getMembersOfType<StructField>(declRef) )
+ {
+ if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ outFields.Add(fieldDeclRef);
+ }
+}
+
+static bool validateTypesMatch(
+ ParameterBindingContext* context,
+ Type* left,
+ Type* right,
+ StructuralTypeMatchStack* stack);
+
+static bool validateIntValuesMatch(
+ ParameterBindingContext* context,
+ IntVal* left,
+ IntVal* right,
+ StructuralTypeMatchStack* stack)
+{
+ if(left->EqualsVal(right))
+ return true;
+
+ // TODO: are there other cases we need to handle here?
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+}
+
+
+static bool validateValuesMatch(
+ ParameterBindingContext* context,
+ Val* left,
+ Val* right,
+ StructuralTypeMatchStack* stack)
+{
+ if( auto leftType = dynamic_cast<Type*>(left) )
+ {
+ if( auto rightType = dynamic_cast<Type*>(right) )
+ {
+ return validateTypesMatch(context, leftType, rightType, stack);
+ }
+ }
+
+ if( auto leftInt = dynamic_cast<IntVal*>(left) )
+ {
+ if( auto rightInt = dynamic_cast<IntVal*>(right) )
+ {
+ return validateIntValuesMatch(context, leftInt, rightInt, stack);
+ }
+ }
+
+ if( auto leftWitness = dynamic_cast<SubtypeWitness*>(left) )
+ {
+ if( auto rightWitness = dynamic_cast<SubtypeWitness*>(right) )
+ {
+ return true;
+ }
+ }
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+}
+
+static bool validateGenericSubstitutionsMatch(
+ ParameterBindingContext* context,
+ GenericSubstitution* left,
+ GenericSubstitution* right,
+ StructuralTypeMatchStack* stack)
+{
+ if( !left )
+ {
+ if( !right )
+ {
+ return true;
+ }
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+ }
+
+
+
+ UInt argCount = left->args.Count();
+ if( argCount != right->args.Count() )
+ {
+ diagnoseTypeMismatch(context, stack);
+ return false;
+ }
+
+ for( UInt aa = 0; aa < argCount; ++aa )
+ {
+ auto leftArg = left->args[aa];
+ auto rightArg = right->args[aa];
+
+ if(!validateValuesMatch(context, leftArg, rightArg, stack))
+ return false;
+ }
+
+ return true;
+}
+
+static bool validateSpecializationsMatch(
+ ParameterBindingContext* context,
+ SubstitutionSet left,
+ SubstitutionSet right,
+ StructuralTypeMatchStack* stack)
+{
+ if(!validateGenericSubstitutionsMatch(
+ context,
+ left.genericSubstitutions,
+ right.genericSubstitutions,
+ stack))
+ {
+ return false;
+ }
+
+ // TODO: anything else to match?
+
+ return true;
+}
+
+// Determine if two types "match" for the purposes of `cbuffer` layout rules.
+//
+static bool validateTypesMatch(
+ ParameterBindingContext* context,
+ Type* left,
+ Type* right,
+ StructuralTypeMatchStack* stack)
+{
+ if(left->Equals(right))
+ return true;
+
+ // It is possible that the types don't match exactly, but
+ // they *do* match structurally.
+
+ // Note: the following code will lead to infinite recursion if there
+ // are ever recursive types. We'd need a more refined system to
+ // cache the matches we've already found.
+
+ if( auto leftDeclRefType = left->As<DeclRefType>() )
+ {
+ if( auto rightDeclRefType = right->As<DeclRefType>() )
+ {
+ // Are they references to matching decl refs?
+ auto leftDeclRef = leftDeclRefType->declRef;
+ auto rightDeclRef = rightDeclRefType->declRef;
+
+ // Do the reference the same declaration? Or declarations
+ // with the same name?
+ //
+ // TODO: we should only consider the same-name case if the
+ // declarations come from translation units being compiled
+ // (and not an imported module).
+ if( leftDeclRef.getDecl() == rightDeclRef.getDecl()
+ || leftDeclRef.GetName() == rightDeclRef.GetName() )
+ {
+ // Check that any generic arguments match
+ if( !validateSpecializationsMatch(
+ context,
+ leftDeclRef.substitutions,
+ rightDeclRef.substitutions,
+ stack) )
+ {
+ return false;
+ }
+
+ // Check that any declared fields match too.
+ if( auto leftStructDeclRef = leftDeclRef.As<AggTypeDecl>() )
+ {
+ if( auto rightStructDeclRef = rightDeclRef.As<AggTypeDecl>() )
+ {
+ List<DeclRef<StructField>> leftFields;
+ List<DeclRef<StructField>> rightFields;
+
+ collectFields(leftStructDeclRef, leftFields);
+ collectFields(rightStructDeclRef, rightFields);
+
+ UInt leftFieldCount = leftFields.Count();
+ UInt rightFieldCount = rightFields.Count();
+
+ if( leftFieldCount != rightFieldCount )
+ {
+ diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack);
+ return false;
+ }
+
+ for( UInt ii = 0; ii < leftFieldCount; ++ii )
+ {
+ auto leftField = leftFields[ii];
+ auto rightField = rightFields[ii];
+
+ if( leftField.GetName() != rightField.GetName() )
+ {
+ diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack);
+ return false;
+ }
+
+ auto leftFieldType = GetType(leftField);
+ auto rightFieldType = GetType(rightField);
+
+ StructuralTypeMatchStack subStack;
+ subStack.parent = stack;
+ subStack.leftDecl = leftField;
+ subStack.rightDecl = rightField;
+
+ if(!validateTypesMatch(context, leftFieldType,rightFieldType, &subStack))
+ return false;
+ }
+ }
+ }
+
+ // Everything seemed to match recursively.
+ return true;
+ }
+ }
+ }
+
+ // If we are looking at `T[N]` and `U[M]` we want to check that
+ // `T` is structurally equivalent to `U` and `N` is the same as `M`.
+ else if( auto leftArrayType = left->As<ArrayExpressionType>() )
+ {
+ if( auto rightArrayType = right->As<ArrayExpressionType>() )
+ {
+ if(!validateTypesMatch(context, leftArrayType->baseType, rightArrayType->baseType, stack) )
+ return false;
+
+ if(!validateValuesMatch(context, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack))
+ return false;
+
+ return true;
+ }
+ }
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+}
+
// This function is supposed to determine if two global shader
// parameter declarations represent the same logical parameter
// (so that they should get the exact same binding(s) allocated).
//
static bool doesParameterMatch(
- ParameterBindingContext*,
+ ParameterBindingContext* context,
RefPtr<VarLayout> varLayout,
- ParameterInfo*)
+ ParameterInfo* parameterInfo)
{
// Any "varying" parameter should automatically be excluded
//
@@ -378,9 +711,12 @@ static bool doesParameterMatch(
}
}
- // TODO: this is where we should apply a more detailed
- // matching process, to check that the existing
- // declarations conform to the same basic layout.
+ StructuralTypeMatchStack stack;
+ stack.parent = nullptr;
+ stack.leftDecl = varLayout->varDecl;
+ stack.rightDecl = parameterInfo->varLayouts[0]->varDecl;
+
+ validateTypesMatch(context, varLayout->typeLayout->type, parameterInfo->varLayouts[0]->typeLayout->type, &stack);
return true;
}
@@ -415,14 +751,6 @@ static bool findLayoutArg(
//
-static Name* getReflectionName(VarDeclBase* varDecl)
-{
- if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>())
- return reflectionNameModifier->nameAndLoc.name;
-
- return varDecl->getName();
-}
-
static bool isGLSLBuiltinName(VarDeclBase* varDecl)
{
return getText(getReflectionName(varDecl)).StartsWith("gl_");