diff options
Diffstat (limited to 'source/slang')
| -rw-r--r-- | source/slang/diagnostic-defs.h | 12 | ||||
| -rw-r--r-- | source/slang/ir-legalize-types.cpp | 17 | ||||
| -rw-r--r-- | source/slang/legalize-types.cpp | 25 | ||||
| -rw-r--r-- | source/slang/legalize-types.h | 16 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 354 |
5 files changed, 387 insertions, 37 deletions
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index f8b09a196..698591d35 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -241,6 +241,18 @@ DIAGNOSTIC(39999, Error, invalidFloatingPOintLiteralSuffix, "invalid suffix '$0' DIAGNOSTIC(39999, Error, conflictingExplicitBindingsForParameter, "conflicting explicit bindings for parameter '$0'") DIAGNOSTIC(39999, Warning, parameterBindingsOverlap, "explicit binding for parameter '$0' overlaps with parameter '$1'") + +DIAGNOSTIC(39999, Error, shaderParameterDeclarationsDontMatch, "declarations of shader parameter '$0' in different translation units don't match") + +DIAGNOSTIC(39999, Note, shaderParameterTypeMismatch, "type is declared as '$0' in one translation unit, and '$0' in another") +DIAGNOSTIC(39999, Note, fieldTypeMisMatch, "type of field '$0' is declared as '$1' in one translation unit, and '$2' in another") +DIAGNOSTIC(39999, Note, fieldDeclarationsDontMatch, "type '$0' is declared with different fields in each translation unit") +DIAGNOSTIC(39999, Note, usedInDeclarationOf, "used in declaration of '$0'") + + + + + DIAGNOSTIC(38000, Error, entryPointFunctionNotFound, "no function found matching entry point name '$0'") DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches entry point name '$0'") DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'") diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 9234af8f5..4e3bafd31 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -13,6 +13,7 @@ #include "ir.h" #include "ir-insts.h" #include "legalize-types.h" +#include "mangle.h" namespace Slang { @@ -277,7 +278,7 @@ static LegalVal legalizeLoad( for (auto ee : legalPtrVal.getTuple()->elements) { TuplePseudoVal::Element element; - element.fieldDeclRef = ee.fieldDeclRef; + element.mangledName = ee.mangledName; element.val = legalizeLoad(context, ee.val); tupleVal->elements.Add(element); @@ -366,11 +367,13 @@ static LegalVal legalizeFieldAddress( case LegalVal::Flavor::pair: { + String mangledFieldName = getMangledName(fieldDeclRef.getDecl()); + // There are two sides, the ordinary and the special, // and we basically just dispatch to both of them. auto pairVal = legalPtrOperand.getPair(); auto pairInfo = pairVal->pairInfo; - auto pairElement = pairInfo->findElement(fieldDeclRef); + auto pairElement = pairInfo->findElement(mangledFieldName); if (!pairElement) { SLANG_UNEXPECTED("didn't find tuple element"); @@ -424,6 +427,8 @@ static LegalVal legalizeFieldAddress( case LegalVal::Flavor::tuple: { + String mangledFieldName = getMangledName(fieldDeclRef.getDecl()); + // The operand is a tuple of pointer-like // values, we want to extract the element // corresponding to a field. We will handle @@ -432,7 +437,7 @@ static LegalVal legalizeFieldAddress( auto ptrTupleInfo = legalPtrOperand.getTuple(); for (auto ee : ptrTupleInfo->elements) { - if (ee.fieldDeclRef.Equals(fieldDeclRef)) + if (ee.mangledName == mangledFieldName) { return ee.val; } @@ -542,7 +547,7 @@ static LegalVal legalizeGetElementPtr( auto elemType = tupleType->elements[ee].type; TuplePseudoVal::Element resElem; - resElem.fieldDeclRef = ptrElem.fieldDeclRef; + resElem.mangledName = ptrElem.mangledName; resElem.val = legalizeGetElementPtr( context, elemType, @@ -1001,7 +1006,7 @@ static LegalVal declareVars( for (auto ee : tupleType->elements) { - auto fieldLayout = getFieldLayout(typeLayout, ee.fieldDeclRef); + auto fieldLayout = getFieldLayout(typeLayout, ee.mangledName); RefPtr<TypeLayout> fieldTypeLayout = fieldLayout ? fieldLayout->typeLayout : nullptr; // If we are processing layout information, then @@ -1026,7 +1031,7 @@ static LegalVal declareVars( globalNameInfo); TuplePseudoVal::Element element; - element.fieldDeclRef = ee.fieldDeclRef; + element.mangledName = ee.mangledName; element.val = fieldVal; tupleVal->elements.Add(element); } diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp index d0cf2ab69..c90b12558 100644 --- a/source/slang/legalize-types.cpp +++ b/source/slang/legalize-types.cpp @@ -232,10 +232,11 @@ struct TupleTypeBuilder break; } + String mangledFieldName = getMangledName(fieldDeclRef.getDecl()); PairInfo::Element pairElement; pairElement.flags = 0; - pairElement.fieldDeclRef = fieldDeclRef; + pairElement.mangledName = mangledFieldName; pairElement.fieldPairInfo = elementPairInfo; // We will always add a field to the "ordinary" @@ -272,7 +273,7 @@ struct TupleTypeBuilder pairElement.flags |= PairInfo::kFlag_hasSpecial; TuplePseudoType::Element specialElement; - specialElement.fieldDeclRef = fieldDeclRef; + specialElement.mangledName = mangledFieldName; specialElement.type = specialType; specialElements.Add(specialElement); } @@ -557,7 +558,7 @@ static LegalType createLegalUniformBufferType( { TuplePseudoType::Element newElement; - newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.mangledName = ee.mangledName; newElement.type = LegalType::implicitDeref(ee.type); bufferPseudoTupleType->elements.Add(newElement); @@ -657,7 +658,7 @@ static LegalType createLegalPtrType( { TuplePseudoType::Element newElement; - newElement.fieldDeclRef = ee.fieldDeclRef; + newElement.mangledName = ee.mangledName; newElement.type = createLegalPtrType( context, typeDeclRef, @@ -772,7 +773,7 @@ static LegalType wrapLegalType( { TuplePseudoType::Element element; - element.fieldDeclRef = ee.fieldDeclRef; + element.mangledName = ee.mangledName; element.type = wrapLegalType( context, ee.type, @@ -988,8 +989,8 @@ RefPtr<TypeLayout> getDerefTypeLayout( } RefPtr<VarLayout> getFieldLayout( - TypeLayout* typeLayout, - DeclRef<VarDeclBase> fieldDeclRef) + TypeLayout* typeLayout, + String const& mangledFieldName) { if (!typeLayout) return nullptr; @@ -1013,9 +1014,13 @@ RefPtr<VarLayout> getFieldLayout( if (auto structTypeLayout = dynamic_cast<StructTypeLayout*>(typeLayout)) { - RefPtr<VarLayout> fieldLayout; - if (structTypeLayout->mapVarToLayout.TryGetValue(fieldDeclRef.getDecl(), fieldLayout)) - return fieldLayout; + for(auto ff : structTypeLayout->fields) + { + if(mangledFieldName == getMangledName(ff->varDecl) ) + { + return ff; + } + } } return nullptr; diff --git a/source/slang/legalize-types.h b/source/slang/legalize-types.h index 853b9f47f..2dffe1db9 100644 --- a/source/slang/legalize-types.h +++ b/source/slang/legalize-types.h @@ -138,7 +138,7 @@ struct TuplePseudoType : LegalTypeImpl struct Element { // The field that this element replaces - DeclRef<VarDeclBase> fieldDeclRef; + String mangledName; // The legalized type of the element LegalType type; @@ -161,7 +161,7 @@ struct PairInfo : RefObject struct Element { // The original field the element represents - DeclRef<Decl> fieldDeclRef; + String mangledName; // The conceptual type of the field. // If both the `hasOrdinary` and @@ -192,11 +192,11 @@ struct PairInfo : RefObject // which fields are on which side(s). List<Element> elements; - Element* findElement(DeclRef<Decl> const& fieldDeclRef) + Element* findElement(String const& mangledName) { for (auto& ee : elements) { - if(ee.fieldDeclRef.Equals(fieldDeclRef)) + if(ee.mangledName == mangledName) return ⅇ } return nullptr; @@ -227,8 +227,8 @@ RefPtr<TypeLayout> getDerefTypeLayout( TypeLayout* typeLayout); RefPtr<VarLayout> getFieldLayout( - TypeLayout* typeLayout, - DeclRef<VarDeclBase> fieldDeclRef); + TypeLayout* typeLayout, + String const& mangledFieldName); // Represents the "chain" of declarations that // were followed to get to a variable that we @@ -321,8 +321,8 @@ struct TuplePseudoVal : LegalValImpl { struct Element { - DeclRef<VarDeclBase> fieldDeclRef; - LegalVal val; + String mangledName; + LegalVal val; }; List<Element> elements; diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index e1c5c1aca..37642ac81 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -351,14 +351,347 @@ LayoutSemanticInfo ExtractLayoutSemanticInfo( return info; } +static Name* getReflectionName(VarDeclBase* varDecl) +{ + if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>()) + return reflectionNameModifier->nameAndLoc.name; + + return varDecl->getName(); +} + +// Information tracked when doing a structural +// match of types. +struct StructuralTypeMatchStack +{ + DeclRef<VarDeclBase> leftDecl; + DeclRef<VarDeclBase> rightDecl; + StructuralTypeMatchStack* parent; +}; + +static void diagnoseParameterTypeMismatch( + ParameterBindingContext* context, + StructuralTypeMatchStack* inStack) +{ + assert(inStack); + + // The bottom-most entry in the stack should represent + // the shader parameters that kicked things off + auto stack = inStack; + while(stack->parent) + stack = stack->parent; + + getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl)); + getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeMismatch( + ParameterBindingContext* context, + StructuralTypeMatchStack* inStack) +{ + auto stack = inStack; + assert(stack); + diagnoseParameterTypeMismatch(context, stack); + + auto leftType = GetType(stack->leftDecl); + auto rightType = GetType(stack->rightDecl); + + if( stack->parent ) + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType); + getSink(context)->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl)); + + stack = stack->parent; + if( stack ) + { + while( stack->parent ) + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } + } + else + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType); + } +} + +// Two types that were expected to match did not. +// Inform the user with a suitable message. +static void diagnoseTypeFieldsMismatch( + ParameterBindingContext* context, + DeclRef<Decl> const& left, + DeclRef<Decl> const& right, + StructuralTypeMatchStack* stack) +{ + diagnoseParameterTypeMismatch(context, stack); + + getSink(context)->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName()); + getSink(context)->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName()); + + if( stack ) + { + while( stack->parent ) + { + getSink(context)->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl)); + stack = stack->parent; + } + } +} + +static void collectFields( + DeclRef<AggTypeDecl> declRef, + List<DeclRef<StructField>>& outFields) +{ + for( auto fieldDeclRef : getMembersOfType<StructField>(declRef) ) + { + if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>()) + continue; + + outFields.Add(fieldDeclRef); + } +} + +static bool validateTypesMatch( + ParameterBindingContext* context, + Type* left, + Type* right, + StructuralTypeMatchStack* stack); + +static bool validateIntValuesMatch( + ParameterBindingContext* context, + IntVal* left, + IntVal* right, + StructuralTypeMatchStack* stack) +{ + if(left->EqualsVal(right)) + return true; + + // TODO: are there other cases we need to handle here? + + diagnoseTypeMismatch(context, stack); + return false; +} + + +static bool validateValuesMatch( + ParameterBindingContext* context, + Val* left, + Val* right, + StructuralTypeMatchStack* stack) +{ + if( auto leftType = dynamic_cast<Type*>(left) ) + { + if( auto rightType = dynamic_cast<Type*>(right) ) + { + return validateTypesMatch(context, leftType, rightType, stack); + } + } + + if( auto leftInt = dynamic_cast<IntVal*>(left) ) + { + if( auto rightInt = dynamic_cast<IntVal*>(right) ) + { + return validateIntValuesMatch(context, leftInt, rightInt, stack); + } + } + + if( auto leftWitness = dynamic_cast<SubtypeWitness*>(left) ) + { + if( auto rightWitness = dynamic_cast<SubtypeWitness*>(right) ) + { + return true; + } + } + + diagnoseTypeMismatch(context, stack); + return false; +} + +static bool validateGenericSubstitutionsMatch( + ParameterBindingContext* context, + GenericSubstitution* left, + GenericSubstitution* right, + StructuralTypeMatchStack* stack) +{ + if( !left ) + { + if( !right ) + { + return true; + } + + diagnoseTypeMismatch(context, stack); + return false; + } + + + + UInt argCount = left->args.Count(); + if( argCount != right->args.Count() ) + { + diagnoseTypeMismatch(context, stack); + return false; + } + + for( UInt aa = 0; aa < argCount; ++aa ) + { + auto leftArg = left->args[aa]; + auto rightArg = right->args[aa]; + + if(!validateValuesMatch(context, leftArg, rightArg, stack)) + return false; + } + + return true; +} + +static bool validateSpecializationsMatch( + ParameterBindingContext* context, + SubstitutionSet left, + SubstitutionSet right, + StructuralTypeMatchStack* stack) +{ + if(!validateGenericSubstitutionsMatch( + context, + left.genericSubstitutions, + right.genericSubstitutions, + stack)) + { + return false; + } + + // TODO: anything else to match? + + return true; +} + +// Determine if two types "match" for the purposes of `cbuffer` layout rules. +// +static bool validateTypesMatch( + ParameterBindingContext* context, + Type* left, + Type* right, + StructuralTypeMatchStack* stack) +{ + if(left->Equals(right)) + return true; + + // It is possible that the types don't match exactly, but + // they *do* match structurally. + + // Note: the following code will lead to infinite recursion if there + // are ever recursive types. We'd need a more refined system to + // cache the matches we've already found. + + if( auto leftDeclRefType = left->As<DeclRefType>() ) + { + if( auto rightDeclRefType = right->As<DeclRefType>() ) + { + // Are they references to matching decl refs? + auto leftDeclRef = leftDeclRefType->declRef; + auto rightDeclRef = rightDeclRefType->declRef; + + // Do the reference the same declaration? Or declarations + // with the same name? + // + // TODO: we should only consider the same-name case if the + // declarations come from translation units being compiled + // (and not an imported module). + if( leftDeclRef.getDecl() == rightDeclRef.getDecl() + || leftDeclRef.GetName() == rightDeclRef.GetName() ) + { + // Check that any generic arguments match + if( !validateSpecializationsMatch( + context, + leftDeclRef.substitutions, + rightDeclRef.substitutions, + stack) ) + { + return false; + } + + // Check that any declared fields match too. + if( auto leftStructDeclRef = leftDeclRef.As<AggTypeDecl>() ) + { + if( auto rightStructDeclRef = rightDeclRef.As<AggTypeDecl>() ) + { + List<DeclRef<StructField>> leftFields; + List<DeclRef<StructField>> rightFields; + + collectFields(leftStructDeclRef, leftFields); + collectFields(rightStructDeclRef, rightFields); + + UInt leftFieldCount = leftFields.Count(); + UInt rightFieldCount = rightFields.Count(); + + if( leftFieldCount != rightFieldCount ) + { + diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack); + return false; + } + + for( UInt ii = 0; ii < leftFieldCount; ++ii ) + { + auto leftField = leftFields[ii]; + auto rightField = rightFields[ii]; + + if( leftField.GetName() != rightField.GetName() ) + { + diagnoseTypeFieldsMismatch(context, leftDeclRef, rightDeclRef, stack); + return false; + } + + auto leftFieldType = GetType(leftField); + auto rightFieldType = GetType(rightField); + + StructuralTypeMatchStack subStack; + subStack.parent = stack; + subStack.leftDecl = leftField; + subStack.rightDecl = rightField; + + if(!validateTypesMatch(context, leftFieldType,rightFieldType, &subStack)) + return false; + } + } + } + + // Everything seemed to match recursively. + return true; + } + } + } + + // If we are looking at `T[N]` and `U[M]` we want to check that + // `T` is structurally equivalent to `U` and `N` is the same as `M`. + else if( auto leftArrayType = left->As<ArrayExpressionType>() ) + { + if( auto rightArrayType = right->As<ArrayExpressionType>() ) + { + if(!validateTypesMatch(context, leftArrayType->baseType, rightArrayType->baseType, stack) ) + return false; + + if(!validateValuesMatch(context, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack)) + return false; + + return true; + } + } + + diagnoseTypeMismatch(context, stack); + return false; +} + // This function is supposed to determine if two global shader // parameter declarations represent the same logical parameter // (so that they should get the exact same binding(s) allocated). // static bool doesParameterMatch( - ParameterBindingContext*, + ParameterBindingContext* context, RefPtr<VarLayout> varLayout, - ParameterInfo*) + ParameterInfo* parameterInfo) { // Any "varying" parameter should automatically be excluded // @@ -378,9 +711,12 @@ static bool doesParameterMatch( } } - // TODO: this is where we should apply a more detailed - // matching process, to check that the existing - // declarations conform to the same basic layout. + StructuralTypeMatchStack stack; + stack.parent = nullptr; + stack.leftDecl = varLayout->varDecl; + stack.rightDecl = parameterInfo->varLayouts[0]->varDecl; + + validateTypesMatch(context, varLayout->typeLayout->type, parameterInfo->varLayouts[0]->typeLayout->type, &stack); return true; } @@ -415,14 +751,6 @@ static bool findLayoutArg( // -static Name* getReflectionName(VarDeclBase* varDecl) -{ - if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>()) - return reflectionNameModifier->nameAndLoc.name; - - return varDecl->getName(); -} - static bool isGLSLBuiltinName(VarDeclBase* varDecl) { return getText(getReflectionName(varDecl)).StartsWith("gl_"); |
