diff options
Diffstat (limited to 'source/slang/slang-check-shader.cpp')
| -rw-r--r-- | source/slang/slang-check-shader.cpp | 86 |
1 files changed, 64 insertions, 22 deletions
diff --git a/source/slang/slang-check-shader.cpp b/source/slang/slang-check-shader.cpp index 333690b36..c229e8f96 100644 --- a/source/slang/slang-check-shader.cpp +++ b/source/slang/slang-check-shader.cpp @@ -52,7 +52,8 @@ namespace Slang /// Recursively walk `type` and add any existential/interface specialization parameters to `ioSpecializationParams`. static void _collectExistentialSpecializationParamsRec( SpecializationParams& ioSpecializationParams, - Type* type) + Type* type, + SourceLoc loc) { // Whether or not something is an array does not affect // the number of existential slots it introduces. @@ -66,7 +67,8 @@ namespace Slang { _collectExistentialSpecializationParamsRec( ioSpecializationParams, - parameterGroupType->getElementType()); + parameterGroupType->getElementType(), + loc); return; } @@ -81,6 +83,7 @@ namespace Slang // SpecializationParam specializationParam; specializationParam.flavor = SpecializationParam::Flavor::ExistentialType; + specializationParam.loc = loc; specializationParam.object = type; ioSpecializationParams.add(specializationParam); } @@ -112,7 +115,8 @@ namespace Slang { _collectExistentialSpecializationParamsRec( ioSpecializationParams, - GetType(paramDeclRef)); + GetType(paramDeclRef), + paramDeclRef.getLoc()); } @@ -147,6 +151,7 @@ namespace Slang { SpecializationParam param; param.flavor = SpecializationParam::Flavor::GenericType; + param.loc = genericTypeParam->loc; param.object = genericTypeParam; m_genericSpecializationParams.add(param); } @@ -154,6 +159,7 @@ namespace Slang { SpecializationParam param; param.flavor = SpecializationParam::Flavor::GenericValue; + param.loc = genericValParam->loc; param.object = genericValParam; m_genericSpecializationParams.add(param); } @@ -1020,9 +1026,21 @@ static bool doesParameterMatch( // SpecializationParam specializationParam; specializationParam.flavor = SpecializationParam::Flavor::GenericType; + specializationParam.loc = globalGenericParam->loc; specializationParam.object = globalGenericParam; m_specializationParams.add(specializationParam); } + else if( auto globalGenericValueParam = as<GlobalGenericValueParamDecl>(globalDecl) ) + { + // A global generic type parameter declaration introduces + // a suitable specialization parameter. + // + SpecializationParam specializationParam; + specializationParam.flavor = SpecializationParam::Flavor::GenericValue; + specializationParam.loc = globalGenericValueParam->loc; + specializationParam.object = globalGenericValueParam; + m_specializationParams.add(specializationParam); + } else if( auto importDecl = as<ImportDecl>(globalDecl) ) { // An `import` declaration creates a requirement dependency @@ -1425,9 +1443,6 @@ static bool doesParameterMatch( auto& arg = args[ii]; auto& param = m_specializationParams[ii]; - auto argType = arg.val.as<Type>(); - SLANG_ASSERT(argType); - switch( param.flavor ) { case SpecializationParam::Flavor::GenericType: @@ -1435,6 +1450,13 @@ static bool doesParameterMatch( auto genericTypeParamDecl = param.object.as<GlobalGenericParamDecl>(); SLANG_ASSERT(genericTypeParamDecl); + RefPtr<Type> argType = as<Type>(arg.val); + if(!argType) + { + sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, genericTypeParamDecl); + argType = getLinkage()->getSessionImpl()->getErrorType(); + } + // TODO: There is a serious flaw to this checking logic if we ever have cases where // the constraints on one `type_param` can depend on another `type_param`, e.g.: // @@ -1520,6 +1542,13 @@ static bool doesParameterMatch( auto interfaceType = param.object.as<Type>(); SLANG_ASSERT(interfaceType); + RefPtr<Type> argType = as<Type>(arg.val); + if(!argType) + { + sink->diagnose(param.loc, Diagnostics::expectedTypeForSpecializationArg, interfaceType); + argType = getLinkage()->getSessionImpl()->getErrorType(); + } + auto witness = visitor.tryGetSubtypeWitness(argType, interfaceType); if (!witness) { @@ -1539,6 +1568,29 @@ static bool doesParameterMatch( } break; + case SpecializationParam::Flavor::GenericValue: + { + auto paramDecl = param.object.as<GlobalGenericValueParamDecl>(); + SLANG_ASSERT(paramDecl); + + // Now we need to check that the argument `Val` has the + // appropriate type expected by the parameter. + + RefPtr<IntVal> intVal = as<IntVal>(arg.val); + if(!intVal) + { + sink->diagnose(param.loc, Diagnostics::expectedValueOfTypeForSpecializationArg, paramDecl->getType(), paramDecl); + intVal = new ConstantIntVal(0); + } + + ModuleSpecializationInfo::GenericArgInfo expandedArg; + expandedArg.paramDecl = paramDecl; + expandedArg.argVal = intVal; + + specializationInfo->genericArgs.add(expandedArg); + } + break; + default: SLANG_UNEXPECTED("unhandled specialization parameter flavor"); } @@ -1556,27 +1608,17 @@ static bool doesParameterMatch( { auto linkage = componentType->getLinkage(); + SharedSemanticsContext semanticsContext(linkage, sink); + SemanticsVisitor semanticsVisitor(&semanticsContext); + auto argCount = argExprs.getCount(); for(Index ii = 0; ii < argCount; ++ii ) { auto argExpr = argExprs[ii]; auto paramInfo = componentType->getSpecializationParam(ii); - // TODO: We should support non-type arguments here - - auto argType = checkProperType(linkage, TypeExp(argExpr), sink); - if( !argType ) - { - // If no witness was found, then we will be unable to satisfy - // the conformances required. - sink->diagnose(argExpr, - Diagnostics::expectedAType, - argExpr->type); - continue; - } - SpecializationArg arg; - arg.val = argType; + arg.val = semanticsVisitor.ExtractGenericArgVal(argExpr); outArgs.add(arg); } } @@ -1757,7 +1799,7 @@ static bool doesParameterMatch( RefPtr<Expr> argExpr; for (auto & s : scopesToTry) { - argExpr = linkage->parseTypeString(name, s); + argExpr = linkage->parseTermString(name, s); argExpr = semantics.CheckTerm(argExpr); if( argExpr ) { @@ -1790,7 +1832,7 @@ static bool doesParameterMatch( SemanticsVisitor visitor(&sharedSemanticsContext); SpecializationParams specializationParams; - _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType); + _collectExistentialSpecializationParamsRec(specializationParams, unspecializedType, SourceLoc()); assert(specializationParams.getCount() == argCount); |
