summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/diagnostic-defs.h12
-rw-r--r--source/slang/ir-legalize-types.cpp17
-rw-r--r--source/slang/legalize-types.cpp25
-rw-r--r--source/slang/legalize-types.h16
-rw-r--r--source/slang/parameter-binding.cpp354
5 files changed, 387 insertions, 37 deletions
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index f8b09a196..698591d35 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -241,6 +241,18 @@ DIAGNOSTIC(39999, Error, invalidFloatingPOintLiteralSuffix, "invalid suffix '$0'
DIAGNOSTIC(39999, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'")
DIAGNOSTIC(39999, Warning, parameterBindingsOverlap, "explicit binding for parameter '$0' overlaps with parameter '$1'")
+
+DIAGNOSTIC(39999, Error, shaderParameterDeclarationsDontMatch, "declarations of shader parameter '$0' in different translation units don't match")
+
+DIAGNOSTIC(39999, Note, shaderParameterTypeMismatch, "type is declared as '$0' in one translation unit, and '$0' in another")
+DIAGNOSTIC(39999, Note, fieldTypeMisMatch, "type of field '$0' is declared as '$1' in one translation unit, and '$2' in another")
+DIAGNOSTIC(39999, Note, fieldDeclarationsDontMatch, "type '$0' is declared with different fields in each translation unit")
+DIAGNOSTIC(39999, Note, usedInDeclarationOf, "used in declaration of '$0'")
+
+
+
+
+
DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'")
DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches entry point name '$0'")
DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'")
diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp
index 9234af8f5..4e3bafd31 100644
--- a/source/slang/ir-legalize-types.cpp
+++ b/source/slang/ir-legalize-types.cpp
@@ -13,6 +13,7 @@
#include "ir.h"
#include "ir-insts.h"
#include "legalize-types.h"
+#include "mangle.h"
namespace Slang
{
@@ -277,7 +278,7 @@ static LegalVal legalizeLoad(
for (auto ee : legalPtrVal.getTuple()->elements)
{
TuplePseudoVal::Element element;
- element.fieldDeclRef = ee.fieldDeclRef;
+ element.mangledName = ee.mangledName;
element.val = legalizeLoad(context, ee.val);
tupleVal->elements.Add(element);
@@ -366,11 +367,13 @@ static LegalVal legalizeFieldAddress(
case LegalVal::Flavor::pair:
{
+ String mangledFieldName = getMangledName(fieldDeclRef.getDecl());
+
// There are two sides, the ordinary and the special,
// and we basically just dispatch to both of them.
auto pairVal = legalPtrOperand.getPair();
auto pairInfo = pairVal->pairInfo;
- auto pairElement = pairInfo->findElement(fieldDeclRef);
+ auto pairElement = pairInfo->findElement(mangledFieldName);
if (!pairElement)
{
SLANG_UNEXPECTED("didn't find tuple element");
@@ -424,6 +427,8 @@ static LegalVal legalizeFieldAddress(
case LegalVal::Flavor::tuple:
{
+ String mangledFieldName = getMangledName(fieldDeclRef.getDecl());
+
// The operand is a tuple of pointer-like
// values, we want to extract the element
// corresponding to a field. We will handle
@@ -432,7 +437,7 @@ static LegalVal legalizeFieldAddress(
auto ptrTupleInfo = legalPtrOperand.getTuple();
for (auto ee : ptrTupleInfo->elements)
{
- if (ee.fieldDeclRef.Equals(fieldDeclRef))
+ if (ee.mangledName == mangledFieldName)
{
return ee.val;
}
@@ -542,7 +547,7 @@ static LegalVal legalizeGetElementPtr(
auto elemType = tupleType->elements[ee].type;
TuplePseudoVal::Element resElem;
- resElem.fieldDeclRef = ptrElem.fieldDeclRef;
+ resElem.mangledName = ptrElem.mangledName;
resElem.val = legalizeGetElementPtr(
context,
elemType,
@@ -1001,7 +1006,7 @@ static LegalVal declareVars(
for (auto ee : tupleType->elements)
{
- auto fieldLayout = getFieldLayout(typeLayout, ee.fieldDeclRef);
+ auto fieldLayout = getFieldLayout(typeLayout, ee.mangledName);
RefPtr<TypeLayout> fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr;
// If we are processing layout information, then
@@ -1026,7 +1031,7 @@ static LegalVal declareVars(
globalNameInfo);
TuplePseudoVal::Element element;
- element.fieldDeclRef = ee.fieldDeclRef;
+ element.mangledName = ee.mangledName;
element.val = fieldVal;
tupleVal->elements.Add(element);
}
diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp
index d0cf2ab69..c90b12558 100644
--- a/source/slang/legalize-types.cpp
+++ b/source/slang/legalize-types.cpp
@@ -232,10 +232,11 @@ struct TupleTypeBuilder
break;
}
+ String mangledFieldName = getMangledName(fieldDeclRef.getDecl());
PairInfo::Element pairElement;
pairElement.flags = 0;
- pairElement.fieldDeclRef = fieldDeclRef;
+ pairElement.mangledName = mangledFieldName;
pairElement.fieldPairInfo = elementPairInfo;
// We will always add a field to the "ordinary"
@@ -272,7 +273,7 @@ struct TupleTypeBuilder
pairElement.flags |= PairInfo::kFlag_hasSpecial;
TuplePseudoType::Element specialElement;
- specialElement.fieldDeclRef = fieldDeclRef;
+ specialElement.mangledName = mangledFieldName;
specialElement.type = specialType;
specialElements.Add(specialElement);
}
@@ -557,7 +558,7 @@ static LegalType createLegalUniformBufferType(
{
TuplePseudoType::Element newElement;
- newElement.fieldDeclRef = ee.fieldDeclRef;
+ newElement.mangledName = ee.mangledName;
newElement.type = LegalType::implicitDeref(ee.type);
bufferPseudoTupleType->elements.Add(newElement);
@@ -657,7 +658,7 @@ static LegalType createLegalPtrType(
{
TuplePseudoType::Element newElement;
- newElement.fieldDeclRef = ee.fieldDeclRef;
+ newElement.mangledName = ee.mangledName;
newElement.type = createLegalPtrType(
context,
typeDeclRef,
@@ -772,7 +773,7 @@ static LegalType wrapLegalType(
{
TuplePseudoType::Element element;
- element.fieldDeclRef = ee.fieldDeclRef;
+ element.mangledName = ee.mangledName;
element.type = wrapLegalType(
context,
ee.type,
@@ -988,8 +989,8 @@ RefPtr<TypeLayout> getDerefTypeLayout(
}
RefPtr<VarLayout> getFieldLayout(
- TypeLayout* typeLayout,
- DeclRef<VarDeclBase> fieldDeclRef)
+ TypeLayout* typeLayout,
+ String const& mangledFieldName)
{
if (!typeLayout)
return nullptr;
@@ -1013,9 +1014,13 @@ RefPtr<VarLayout> getFieldLayout(
if (auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout))
{
- RefPtr<VarLayout> fieldLayout;
- if (structTypeLayout->mapVarToLayout.TryGetValue(fieldDeclRef.getDecl(), fieldLayout))
- return fieldLayout;
+ for(auto ff : structTypeLayout->fields)
+ {
+ if(mangledFieldName == getMangledName(ff->varDecl) )
+ {
+ return ff;
+ }
+ }
}
return nullptr;
diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h
index 853b9f47f..2dffe1db9 100644
--- a/source/slang/legalize-types.h
+++ b/source/slang/legalize-types.h
@@ -138,7 +138,7 @@ struct TuplePseudoType : LegalTypeImpl
struct Element
{
// The field that this element replaces
- DeclRef<VarDeclBase> fieldDeclRef;
+ String mangledName;
// The legalized type of the element
LegalType type;
@@ -161,7 +161,7 @@ struct PairInfo : RefObject
struct Element
{
// The original field the element represents
- DeclRef<Decl> fieldDeclRef;
+ String mangledName;
// The conceptual type of the field.
// If both the `hasOrdinary` and
@@ -192,11 +192,11 @@ struct PairInfo : RefObject
// which fields are on which side(s).
List<Element> elements;
- Element* findElement(DeclRef<Decl> const& fieldDeclRef)
+ Element* findElement(String const& mangledName)
{
for (auto& ee : elements)
{
- if(ee.fieldDeclRef.Equals(fieldDeclRef))
+ if(ee.mangledName == mangledName)
return &ee;
}
return nullptr;
@@ -227,8 +227,8 @@ RefPtr<TypeLayout> getDerefTypeLayout(
TypeLayout* typeLayout);
RefPtr<VarLayout> getFieldLayout(
- TypeLayout* typeLayout,
- DeclRef<VarDeclBase> fieldDeclRef);
+ TypeLayout* typeLayout,
+ String const& mangledFieldName);
// Represents the "chain" of declarations that
// were followed to get to a variable that we
@@ -321,8 +321,8 @@ struct TuplePseudoVal : LegalValImpl
{
struct Element
{
- DeclRef<VarDeclBase> fieldDeclRef;
- LegalVal val;
+ String mangledName;
+ LegalVal val;
};
List<Element> elements;
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index e1c5c1aca..37642ac81 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -351,14 +351,347 @@ LayoutSemanticInfo ExtractLayoutSemanticInfo(
return info;
}
+static Name* getReflectionName(VarDeclBase* varDecl)
+{
+ if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>())
+ return reflectionNameModifier->nameAndLoc.name;
+
+ return varDecl->getName();
+}
+
+// Information tracked when doing a structural
+// match of types.
+struct StructuralTypeMatchStack
+{
+ DeclRef<VarDeclBase> leftDecl;
+ DeclRef<VarDeclBase> rightDecl;
+ StructuralTypeMatchStack* parent;
+};
+
+static void diagnoseParameterTypeMismatch(
+ ParameterBindingContext* context,
+ StructuralTypeMatchStack* inStack)
+{
+ assert(inStack);
+
+ // The bottom-most entry in the stack should represent
+ // the shader parameters that kicked things off
+ auto stack = inStack;
+ while(stack->parent)
+ stack = stack->parent;
+
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl));
+ getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl));
+}
+
+// Two types that were expected to match did not.
+// Inform the user with a suitable message.
+static void diagnoseTypeMismatch(
+ ParameterBindingContext* context,
+ StructuralTypeMatchStack* inStack)
+{
+ auto stack = inStack;
+ assert(stack);
+ diagnoseParameterTypeMismatch(context, stack);
+
+ auto leftType = GetType(stack->leftDecl);
+ auto rightType = GetType(stack->rightDecl);
+
+ if( stack->parent )
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType);
+ getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl));
+
+ stack = stack->parent;
+ if( stack )
+ {
+ while( stack->parent )
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl));
+ stack = stack->parent;
+ }
+ }
+ }
+ else
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType);
+ }
+}
+
+// Two types that were expected to match did not.
+// Inform the user with a suitable message.
+static void diagnoseTypeFieldsMismatch(
+ ParameterBindingContext* context,
+ DeclRef<Decl> const& left,
+ DeclRef<Decl> const& right,
+ StructuralTypeMatchStack* stack)
+{
+ diagnoseParameterTypeMismatch(context, stack);
+
+ getSink(context)->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName());
+ getSink(context)->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName());
+
+ if( stack )
+ {
+ while( stack->parent )
+ {
+ getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl));
+ stack = stack->parent;
+ }
+ }
+}
+
+static void collectFields(
+ DeclRef<AggTypeDecl> declRef,
+ List<DeclRef<StructField>>& outFields)
+{
+ for( auto fieldDeclRef : getMembersOfType<StructField>(declRef) )
+ {
+ if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ outFields.Add(fieldDeclRef);
+ }
+}
+
+static bool validateTypesMatch(
+ ParameterBindingContext* context,
+ Type* left,
+ Type* right,
+ StructuralTypeMatchStack* stack);
+
+static bool validateIntValuesMatch(
+ ParameterBindingContext* context,
+ IntVal* left,
+ IntVal* right,
+ StructuralTypeMatchStack* stack)
+{
+ if(left->EqualsVal(right))
+ return true;
+
+ // TODO: are there other cases we need to handle here?
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+}
+
+
+static bool validateValuesMatch(
+ ParameterBindingContext* context,
+ Val* left,
+ Val* right,
+ StructuralTypeMatchStack* stack)
+{
+ if( auto leftType = dynamic_cast<Type*>(left) )
+ {
+ if( auto rightType = dynamic_cast<Type*>(right) )
+ {
+ return validateTypesMatch(context, leftType, rightType, stack);
+ }
+ }
+
+ if( auto leftInt = dynamic_cast<IntVal*>(left) )
+ {
+ if( auto rightInt = dynamic_cast<IntVal*>(right) )
+ {
+ return validateIntValuesMatch(context, leftInt, rightInt, stack);
+ }
+ }
+
+ if( auto leftWitness = dynamic_cast<SubtypeWitness*>(left) )
+ {
+ if( auto rightWitness = dynamic_cast<SubtypeWitness*>(right) )
+ {
+ return true;
+ }
+ }
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+}
+
+static bool validateGenericSubstitutionsMatch(
+ ParameterBindingContext* context,
+ GenericSubstitution* left,
+ GenericSubstitution* right,
+ StructuralTypeMatchStack* stack)
+{
+ if( !left )
+ {
+ if( !right )
+ {
+ return true;
+ }
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+ }
+
+
+
+ UInt argCount = left->args.Count();
+ if( argCount != right->args.Count() )
+ {
+ diagnoseTypeMismatch(context, stack);
+ return false;
+ }
+
+ for( UInt aa = 0; aa < argCount; ++aa )
+ {
+ auto leftArg = left->args[aa];
+ auto rightArg = right->args[aa];
+
+ if(!validateValuesMatch(context, leftArg, rightArg, stack))
+ return false;
+ }
+
+ return true;
+}
+
+static bool validateSpecializationsMatch(
+ ParameterBindingContext* context,
+ SubstitutionSet left,
+ SubstitutionSet right,
+ StructuralTypeMatchStack* stack)
+{
+ if(!validateGenericSubstitutionsMatch(
+ context,
+ left.genericSubstitutions,
+ right.genericSubstitutions,
+ stack))
+ {
+ return false;
+ }
+
+ // TODO: anything else to match?
+
+ return true;
+}
+
+// Determine if two types "match" for the purposes of `cbuffer` layout rules.
+//
+static bool validateTypesMatch(
+ ParameterBindingContext* context,
+ Type* left,
+ Type* right,
+ StructuralTypeMatchStack* stack)
+{
+ if(left->Equals(right))
+ return true;
+
+ // It is possible that the types don't match exactly, but
+ // they *do* match structurally.
+
+ // Note: the following code will lead to infinite recursion if there
+ // are ever recursive types. We'd need a more refined system to
+ // cache the matches we've already found.
+
+ if( auto leftDeclRefType = left->As<DeclRefType>() )
+ {
+ if( auto rightDeclRefType = right->As<DeclRefType>() )
+ {
+ // Are they references to matching decl refs?
+ auto leftDeclRef = leftDeclRefType->declRef;
+ auto rightDeclRef = rightDeclRefType->declRef;
+
+ // Do the reference the same declaration? Or declarations
+ // with the same name?
+ //
+ // TODO: we should only consider the same-name case if the
+ // declarations come from translation units being compiled
+ // (and not an imported module).
+ if( leftDeclRef.getDecl() == rightDeclRef.getDecl()
+ || leftDeclRef.GetName() == rightDeclRef.GetName() )
+ {
+ // Check that any generic arguments match
+ if( !validateSpecializationsMatch(
+ context,
+ leftDeclRef.substitutions,
+ rightDeclRef.substitutions,
+ stack) )
+ {
+ return false;
+ }
+
+ // Check that any declared fields match too.
+ if( auto leftStructDeclRef = leftDeclRef.As<AggTypeDecl>() )
+ {
+ if( auto rightStructDeclRef = rightDeclRef.As<AggTypeDecl>() )
+ {
+ List<DeclRef<StructField>> leftFields;
+ List<DeclRef<StructField>> rightFields;
+
+ collectFields(leftStructDeclRef, leftFields);
+ collectFields(rightStructDeclRef, rightFields);
+
+ UInt leftFieldCount = leftFields.Count();
+ UInt rightFieldCount = rightFields.Count();
+
+ if( leftFieldCount != rightFieldCount )
+ {
+ diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack);
+ return false;
+ }
+
+ for( UInt ii = 0; ii < leftFieldCount; ++ii )
+ {
+ auto leftField = leftFields[ii];
+ auto rightField = rightFields[ii];
+
+ if( leftField.GetName() != rightField.GetName() )
+ {
+ diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack);
+ return false;
+ }
+
+ auto leftFieldType = GetType(leftField);
+ auto rightFieldType = GetType(rightField);
+
+ StructuralTypeMatchStack subStack;
+ subStack.parent = stack;
+ subStack.leftDecl = leftField;
+ subStack.rightDecl = rightField;
+
+ if(!validateTypesMatch(context, leftFieldType,rightFieldType, &subStack))
+ return false;
+ }
+ }
+ }
+
+ // Everything seemed to match recursively.
+ return true;
+ }
+ }
+ }
+
+ // If we are looking at `T[N]` and `U[M]` we want to check that
+ // `T` is structurally equivalent to `U` and `N` is the same as `M`.
+ else if( auto leftArrayType = left->As<ArrayExpressionType>() )
+ {
+ if( auto rightArrayType = right->As<ArrayExpressionType>() )
+ {
+ if(!validateTypesMatch(context, leftArrayType->baseType, rightArrayType->baseType, stack) )
+ return false;
+
+ if(!validateValuesMatch(context, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack))
+ return false;
+
+ return true;
+ }
+ }
+
+ diagnoseTypeMismatch(context, stack);
+ return false;
+}
+
// This function is supposed to determine if two global shader
// parameter declarations represent the same logical parameter
// (so that they should get the exact same binding(s) allocated).
//
static bool doesParameterMatch(
- ParameterBindingContext*,
+ ParameterBindingContext* context,
RefPtr<VarLayout> varLayout,
- ParameterInfo*)
+ ParameterInfo* parameterInfo)
{
// Any "varying" parameter should automatically be excluded
//
@@ -378,9 +711,12 @@ static bool doesParameterMatch(
}
}
- // TODO: this is where we should apply a more detailed
- // matching process, to check that the existing
- // declarations conform to the same basic layout.
+ StructuralTypeMatchStack stack;
+ stack.parent = nullptr;
+ stack.leftDecl = varLayout->varDecl;
+ stack.rightDecl = parameterInfo->varLayouts[0]->varDecl;
+
+ validateTypesMatch(context, varLayout->typeLayout->type, parameterInfo->varLayouts[0]->typeLayout->type, &stack);
return true;
}
@@ -415,14 +751,6 @@ static bool findLayoutArg(
//
-static Name* getReflectionName(VarDeclBase* varDecl)
-{
- if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>())
- return reflectionNameModifier->nameAndLoc.name;
-
- return varDecl->getName();
-}
-
static bool isGLSLBuiltinName(VarDeclBase* varDecl)
{
return getText(getReflectionName(varDecl)).StartsWith("gl_");