summaryrefslogtreecommitdiffstats
path: root/source/slang/check.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'source/slang/check.cpp')
-rw-r--r--source/slang/check.cpp515
1 files changed, 504 insertions, 11 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index a9f84c5c3..998324612 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -9306,7 +9306,7 @@ namespace Slang
}
/// Recursively walk `paramDeclRef` and add any required existential slots to `ioSlots`.
- static void _collectExistentialParamsRec(
+ static void _collectExistentialSlotsRec(
ExistentialSlots& ioSlots,
DeclRef<VarDeclBase> paramDeclRef)
{
@@ -9339,7 +9339,7 @@ namespace Slang
if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
continue;
- _collectExistentialParamsRec(ioSlots, fieldDeclRef);
+ _collectExistentialSlotsRec(ioSlots, fieldDeclRef);
}
}
}
@@ -9349,11 +9349,26 @@ namespace Slang
// element types.
}
+ /// Add information about a shader parameter to `ioParams` and `ioSlots`
+ static void _collectExistentialSlotsForShaderParam(
+ ShaderParamInfo& ioParamInfo,
+ ExistentialSlots& ioSlots,
+ DeclRef<VarDeclBase> paramDeclRef)
+ {
+ UInt startSlot = ioSlots.types.Count();
+ _collectExistentialSlotsRec(ioSlots, paramDeclRef);
+ UInt endSlot = ioSlots.types.Count();
+ UInt slotCount = endSlot - startSlot;
+
+ ioParamInfo.firstExistentialTypeSlot = startSlot;
+ ioParamInfo.existentialTypeSlotCount = slotCount;
+ }
+
/// Enumerate the existential-type parameters of an `EntryPoint`.
///
/// Any parameters found will be added to the list of existential slots on `this`.
///
- void EntryPoint::_collectExistentialParams()
+ void EntryPoint::_collectShaderParams()
{
// Note: we defensively test whether there is a function decl-ref
// because this routine gets called from the constructor, and
@@ -9363,7 +9378,15 @@ namespace Slang
{
for( auto paramDeclRef : GetParameters(funcDeclRef) )
{
- _collectExistentialParamsRec(m_existentialSlots, paramDeclRef);
+ ShaderParamInfo shaderParamInfo;
+ shaderParamInfo.paramDeclRef = paramDeclRef;
+
+ _collectExistentialSlotsForShaderParam(
+ shaderParamInfo,
+ m_existentialSlots,
+ paramDeclRef);
+
+ m_shaderParams.Add(shaderParamInfo);
}
}
}
@@ -9644,16 +9667,433 @@ namespace Slang
return entryPoint;
}
+ /// Get the name a variable will use for reflection purposes
+Name* getReflectionName(VarDeclBase* varDecl)
+{
+ if (auto reflectionNameModifier = varDecl->FindModifier<ParameterGroupReflectionName>())
+ return reflectionNameModifier->nameAndLoc.name;
+
+ return varDecl->getName();
+}
+
+// Information tracked when doing a structural
+// match of types.
+struct StructuralTypeMatchStack
+{
+ DeclRef<VarDeclBase> leftDecl;
+ DeclRef<VarDeclBase> rightDecl;
+ StructuralTypeMatchStack* parent;
+};
+
+static void diagnoseParameterTypeMismatch(
+ DiagnosticSink* sink,
+ StructuralTypeMatchStack* inStack)
+{
+ SLANG_ASSERT(inStack);
+
+ // The bottom-most entry in the stack should represent
+ // the shader parameters that kicked things off
+ auto stack = inStack;
+ while(stack->parent)
+ stack = stack->parent;
+
+ sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterDeclarationsDontMatch, getReflectionName(stack->leftDecl));
+ sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl));
+}
+
+// Two types that were expected to match did not.
+// Inform the user with a suitable message.
+static void diagnoseTypeMismatch(
+ DiagnosticSink* sink,
+ StructuralTypeMatchStack* inStack)
+{
+ auto stack = inStack;
+ SLANG_ASSERT(stack);
+ diagnoseParameterTypeMismatch(sink, stack);
+
+ auto leftType = GetType(stack->leftDecl);
+ auto rightType = GetType(stack->rightDecl);
+
+ if( stack->parent )
+ {
+ sink->diagnose(stack->leftDecl, Diagnostics::fieldTypeMisMatch, getReflectionName(stack->leftDecl), leftType, rightType);
+ sink->diagnose(stack->rightDecl, Diagnostics::seeOtherDeclarationOf, getReflectionName(stack->rightDecl));
+
+ stack = stack->parent;
+ if( stack )
+ {
+ while( stack->parent )
+ {
+ sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl));
+ stack = stack->parent;
+ }
+ }
+ }
+ else
+ {
+ sink->diagnose(stack->leftDecl, Diagnostics::shaderParameterTypeMismatch, leftType, rightType);
+ }
+}
+
+// Two types that were expected to match did not.
+// Inform the user with a suitable message.
+static void diagnoseTypeFieldsMismatch(
+ DiagnosticSink* sink,
+ DeclRef<Decl> const& left,
+ DeclRef<Decl> const& right,
+ StructuralTypeMatchStack* stack)
+{
+ diagnoseParameterTypeMismatch(sink, stack);
+
+ sink->diagnose(left, Diagnostics::fieldDeclarationsDontMatch, left.GetName());
+ sink->diagnose(right, Diagnostics::seeOtherDeclarationOf, right.GetName());
+
+ if( stack )
+ {
+ while( stack->parent )
+ {
+ sink->diagnose(stack->leftDecl, Diagnostics::usedInDeclarationOf, getReflectionName(stack->leftDecl));
+ stack = stack->parent;
+ }
+ }
+}
+
+static void collectFields(
+ DeclRef<AggTypeDecl> declRef,
+ List<DeclRef<VarDecl>>& outFields)
+{
+ for( auto fieldDeclRef : getMembersOfType<VarDecl>(declRef) )
+ {
+ if(fieldDeclRef.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ outFields.Add(fieldDeclRef);
+ }
+}
+
+static bool validateTypesMatch(
+ DiagnosticSink* sink,
+ Type* left,
+ Type* right,
+ StructuralTypeMatchStack* stack);
+
+static bool validateIntValuesMatch(
+ DiagnosticSink* sink,
+ IntVal* left,
+ IntVal* right,
+ StructuralTypeMatchStack* stack)
+{
+ if(left->EqualsVal(right))
+ return true;
+
+ // TODO: are there other cases we need to handle here?
+
+ diagnoseTypeMismatch(sink, stack);
+ return false;
+}
+
+
+static bool validateValuesMatch(
+ DiagnosticSink* sink,
+ Val* left,
+ Val* right,
+ StructuralTypeMatchStack* stack)
+{
+ if( auto leftType = dynamicCast<Type>(left) )
+ {
+ if( auto rightType = dynamicCast<Type>(right) )
+ {
+ return validateTypesMatch(sink, leftType, rightType, stack);
+ }
+ }
+
+ if( auto leftInt = dynamicCast<IntVal>(left) )
+ {
+ if( auto rightInt = dynamicCast<IntVal>(right) )
+ {
+ return validateIntValuesMatch(sink, leftInt, rightInt, stack);
+ }
+ }
+
+ if( auto leftWitness = dynamicCast<SubtypeWitness>(left) )
+ {
+ if( auto rightWitness = dynamicCast<SubtypeWitness>(right) )
+ {
+ return true;
+ }
+ }
+
+ diagnoseTypeMismatch(sink, stack);
+ return false;
+}
+
+static bool validateGenericSubstitutionsMatch(
+ DiagnosticSink* sink,
+ GenericSubstitution* left,
+ GenericSubstitution* right,
+ StructuralTypeMatchStack* stack)
+{
+ if( !left )
+ {
+ if( !right )
+ {
+ return true;
+ }
+
+ diagnoseTypeMismatch(sink, stack);
+ return false;
+ }
+
+
+
+ UInt argCount = left->args.Count();
+ if( argCount != right->args.Count() )
+ {
+ diagnoseTypeMismatch(sink, stack);
+ return false;
+ }
+
+ for( UInt aa = 0; aa < argCount; ++aa )
+ {
+ auto leftArg = left->args[aa];
+ auto rightArg = right->args[aa];
+
+ if(!validateValuesMatch(sink, leftArg, rightArg, stack))
+ return false;
+ }
+
+ return true;
+}
+
+static bool validateThisTypeSubstitutionsMatch(
+ DiagnosticSink* /*sink*/,
+ ThisTypeSubstitution* /*left*/,
+ ThisTypeSubstitution* /*right*/,
+ StructuralTypeMatchStack* /*stack*/)
+{
+ // TODO: actual checking.
+ return true;
+}
+
+static bool validateSpecializationsMatch(
+ DiagnosticSink* sink,
+ SubstitutionSet left,
+ SubstitutionSet right,
+ StructuralTypeMatchStack* stack)
+{
+ auto ll = left.substitutions;
+ auto rr = right.substitutions;
+ for(;;)
+ {
+ // Skip any global generic substitutions.
+ if(auto leftGlobalGeneric = as<GlobalGenericParamSubstitution>(ll))
+ {
+ ll = leftGlobalGeneric->outer;
+ continue;
+ }
+ if(auto rightGlobalGeneric = as<GlobalGenericParamSubstitution>(rr))
+ {
+ rr = rightGlobalGeneric->outer;
+ continue;
+ }
+
+ // If either ran out, then we expect both to have run out.
+ if(!ll || !rr)
+ return !ll && !rr;
+
+ auto leftSubst = ll;
+ auto rightSubst = rr;
+
+ ll = ll->outer;
+ rr = rr->outer;
+
+ if(auto leftGeneric = as<GenericSubstitution>(leftSubst))
+ {
+ if(auto rightGeneric = as<GenericSubstitution>(rightSubst))
+ {
+ if(validateGenericSubstitutionsMatch(sink, leftGeneric, rightGeneric, stack))
+ {
+ continue;
+ }
+ }
+ }
+ else if(auto leftThisType = as<ThisTypeSubstitution>(leftSubst))
+ {
+ if(auto rightThisType = as<ThisTypeSubstitution>(rightSubst))
+ {
+ if(validateThisTypeSubstitutionsMatch(sink, leftThisType, rightThisType, stack))
+ {
+ continue;
+ }
+ }
+ }
+
+ return false;
+ }
+
+ return true;
+}
+
+// Determine if two types "match" for the purposes of `cbuffer` layout rules.
+//
+static bool validateTypesMatch(
+ DiagnosticSink* sink,
+ Type* left,
+ Type* right,
+ StructuralTypeMatchStack* stack)
+{
+ if(left->Equals(right))
+ return true;
+
+ // It is possible that the types don't match exactly, but
+ // they *do* match structurally.
+
+ // Note: the following code will lead to infinite recursion if there
+ // are ever recursive types. We'd need a more refined system to
+ // cache the matches we've already found.
+
+ if( auto leftDeclRefType = as<DeclRefType>(left) )
+ {
+ if( auto rightDeclRefType = as<DeclRefType>(right) )
+ {
+ // Are they references to matching decl refs?
+ auto leftDeclRef = leftDeclRefType->declRef;
+ auto rightDeclRef = rightDeclRefType->declRef;
+
+ // Do the reference the same declaration? Or declarations
+ // with the same name?
+ //
+ // TODO: we should only consider the same-name case if the
+ // declarations come from translation units being compiled
+ // (and not an imported module).
+ if( leftDeclRef.getDecl() == rightDeclRef.getDecl()
+ || leftDeclRef.GetName() == rightDeclRef.GetName() )
+ {
+ // Check that any generic arguments match
+ if( !validateSpecializationsMatch(
+ sink,
+ leftDeclRef.substitutions,
+ rightDeclRef.substitutions,
+ stack) )
+ {
+ return false;
+ }
+
+ // Check that any declared fields match too.
+ if( auto leftStructDeclRef = leftDeclRef.as<AggTypeDecl>() )
+ {
+ if( auto rightStructDeclRef = rightDeclRef.as<AggTypeDecl>() )
+ {
+ List<DeclRef<VarDecl>> leftFields;
+ List<DeclRef<VarDecl>> rightFields;
+
+ collectFields(leftStructDeclRef, leftFields);
+ collectFields(rightStructDeclRef, rightFields);
+
+ UInt leftFieldCount = leftFields.Count();
+ UInt rightFieldCount = rightFields.Count();
+
+ if( leftFieldCount != rightFieldCount )
+ {
+ diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack);
+ return false;
+ }
+
+ for( UInt ii = 0; ii < leftFieldCount; ++ii )
+ {
+ auto leftField = leftFields[ii];
+ auto rightField = rightFields[ii];
+
+ if( leftField.GetName() != rightField.GetName() )
+ {
+ diagnoseTypeFieldsMismatch(sink, leftDeclRef, rightDeclRef, stack);
+ return false;
+ }
+
+ auto leftFieldType = GetType(leftField);
+ auto rightFieldType = GetType(rightField);
+
+ StructuralTypeMatchStack subStack;
+ subStack.parent = stack;
+ subStack.leftDecl = leftField;
+ subStack.rightDecl = rightField;
+
+ if(!validateTypesMatch(sink, leftFieldType,rightFieldType, &subStack))
+ return false;
+ }
+ }
+ }
+
+ // Everything seemed to match recursively.
+ return true;
+ }
+ }
+ }
+
+ // If we are looking at `T[N]` and `U[M]` we want to check that
+ // `T` is structurally equivalent to `U` and `N` is the same as `M`.
+ else if( auto leftArrayType = as<ArrayExpressionType>(left) )
+ {
+ if( auto rightArrayType = as<ArrayExpressionType>(right) )
+ {
+ if(!validateTypesMatch(sink, leftArrayType->baseType, rightArrayType->baseType, stack) )
+ return false;
+
+ if(!validateValuesMatch(sink, leftArrayType->ArrayLength, rightArrayType->ArrayLength, stack))
+ return false;
+
+ return true;
+ }
+ }
+
+ diagnoseTypeMismatch(sink, stack);
+ return false;
+}
+
+// This function is supposed to determine if two global shader
+// parameter declarations represent the same logical parameter
+// (so that they should get the exact same binding(s) allocated).
+//
+static bool doesParameterMatch(
+ DiagnosticSink* sink,
+ DeclRef<VarDeclBase> varDeclRef,
+ DeclRef<VarDeclBase> existingVarDeclRef)
+{
+ StructuralTypeMatchStack stack;
+ stack.parent = nullptr;
+ stack.leftDecl = varDeclRef;
+ stack.rightDecl = existingVarDeclRef;
+
+ validateTypesMatch(sink, GetType(varDeclRef), GetType(existingVarDeclRef), &stack);
+
+ return true;
+}
+
+
+
+
/// Enumerate the existential-type parameters of a `Program`.
///
/// Any parameters found will be added to the list of existential slots on `this`.
///
- void Program::_collectExistentialParams()
+ void Program::_collectShaderParams(DiagnosticSink* sink)
{
- // We need to inspect all of the global shader parameters
+ // We need to collect all of the global shader parameters
// referenced by the compile request, and for each we
- // need to determine what existential types parameters it implies.
+ // need to do a few things:
+ //
+ // * We need to determine if the parameter is a duplicate/redeclaration
+ // of the "same" parameter in another translation unit, and collapse
+ // those into one logical shader parameter if so.
+ //
+ // * We need to determine what existential type slots are introduced
+ // by the parameter, and associate that information with the parameter.
+ //
+ // To deal with the first issue, we will maintain a map from a parameter
+ // name to the index of an existing parameter with that name.
//
+ Dictionary<Name*, Int> mapNameToParamIndex;
+
for( auto module : getModuleDependencies() )
{
auto moduleDecl = module->getModuleDecl();
@@ -9662,9 +10102,55 @@ namespace Slang
if(!isGlobalShaderParameter(globalVar))
continue;
- _collectExistentialParamsRec(
+ // This declaration may represent the same logical parameter
+ // as a declaration that came from a different translation unit.
+ // If that is the case, we want to re-use the same `ShaderParamInfo`
+ // across both parameters.
+ //
+ // TODO: This logic currently detects *any* global-scope parameters
+ // with matching names, but it should eventually be narrowly
+ // scoped so that it only applies to parameters from unnamed modules
+ // (that is, modules that represent directly-compiled shader files
+ // and not `import`ed code).
+ //
+ // First we look for an existing entry matching the name
+ // of this parameter:
+ //
+ auto paramName = getReflectionName(globalVar);
+ Int existingParamIndex = -1;
+ if( mapNameToParamIndex.TryGetValue(paramName, existingParamIndex) )
+ {
+ // If the parameters have the same name, but don't "match" according to some reasonable rules,
+ // then we will treat them as distinct global parameters.
+ //
+ // Note: all of the mismatch cases currently report errors, so that
+ // compilation will fail on a mismatch.
+ //
+ auto& existingParam = m_shaderParams[existingParamIndex];
+ if( doesParameterMatch(sink, makeDeclRef(globalVar.Ptr()), existingParam.paramDeclRef) )
+ {
+ // If we hit this case, then we had a match, and we should
+ // consider the new variable to be a redclaration of
+ // the existing one.
+
+ existingParam.additionalParamDeclRefs.Add(
+ makeDeclRef(globalVar.Ptr()));
+ continue;
+ }
+ }
+
+ Int newParamIndex = Int(m_shaderParams.Count());
+ mapNameToParamIndex.Add(paramName, newParamIndex);
+
+ GlobalShaderParamInfo shaderParamInfo;
+ shaderParamInfo.paramDeclRef = makeDeclRef(globalVar.Ptr());
+
+ _collectExistentialSlotsForShaderParam(
+ shaderParamInfo,
m_globalExistentialSlots,
makeDeclRef(globalVar.Ptr()));
+
+ m_shaderParams.Add(shaderParamInfo);
}
}
}
@@ -9795,7 +10281,7 @@ namespace Slang
}
}
- program->_collectExistentialParams();
+ program->_collectShaderParams(sink);
return program;
}
@@ -10181,8 +10667,15 @@ namespace Slang
specializedProgram->setGlobalGenericSubsitution(globalGenericSubsts);
- // Now deal with the existential arguments
- specializedProgram->_collectExistentialParams();
+ // Now deal with the shader parameters and existential arguments
+ //
+ // Note: We should in theory be able to just copy over the shader
+ // parameters and existential slot information from the unspecialized
+ // program. This could save some time, but it would also mean that
+ // the only way to create a specialized program is by creating an
+ // unspecialized on first, which is maybe not always desirable.
+ //
+ specializedProgram->_collectShaderParams(sink);
specializedProgram->_specializeExistentialSlots(globalExistentialArgs, sink);
return specializedProgram;