diff options
Diffstat (limited to 'source')
| -rw-r--r-- | source/core/list.h | 8 | ||||
| -rw-r--r-- | source/slang/check.cpp | 102 | ||||
| -rw-r--r-- | source/slang/compiler.cpp | 6 | ||||
| -rw-r--r-- | source/slang/compiler.h | 12 | ||||
| -rw-r--r-- | source/slang/decl-defs.h | 5 | ||||
| -rw-r--r-- | source/slang/diagnostic-defs.h | 9 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 11 | ||||
| -rw-r--r-- | source/slang/emit.h | 5 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 4 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 119 | ||||
| -rw-r--r-- | source/slang/lookup.cpp | 7 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 4 | ||||
| -rw-r--r-- | source/slang/lower.cpp | 7 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 239 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 103 | ||||
| -rw-r--r-- | source/slang/reflection.cpp | 19 | ||||
| -rw-r--r-- | source/slang/slang.cpp | 32 | ||||
| -rw-r--r-- | source/slang/syntax-base-defs.h | 32 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 62 | ||||
| -rw-r--r-- | source/slang/syntax.h | 2 | ||||
| -rw-r--r-- | source/slang/type-layout.cpp | 32 | ||||
| -rw-r--r-- | source/slang/type-layout.h | 20 |
22 files changed, 745 insertions, 95 deletions
diff --git a/source/core/list.h b/source/core/list.h index af32a39ef..b1461a260 100644 --- a/source/core/list.h +++ b/source/core/list.h @@ -487,7 +487,7 @@ namespace Slang if (predicate(buffer[i])) return i; } - return -1; + return (UInt)-1; } template<typename Func> @@ -498,7 +498,7 @@ namespace Slang if (predicate(buffer[i-1])) return i-1; } - return -1; + return (UInt)-1; } template<typename T2> @@ -509,7 +509,7 @@ namespace Slang if (buffer[i] == val) return i; } - return -1; + return (UInt)-1; } template<typename T2> @@ -520,7 +520,7 @@ namespace Slang if(buffer[i-1] == val) return i-1; } - return -1; + return (UInt)-1; } void Sort() diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 233a82eef..4b8f4f4c1 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -148,7 +148,6 @@ namespace Slang return expr->type->As<DeclRefType>(); } - RefPtr<Expr> ConstructDeclRefExpr( DeclRef<Decl> declRef, RefPtr<Expr> baseExpr, @@ -1998,6 +1997,22 @@ namespace Slang decl->SetCheckState(DeclCheckState::Checked); } + void visitGlobalGenericParamDecl(GlobalGenericParamDecl * decl) + { + if (decl->IsChecked(DeclCheckState::Checked)) return; + decl->SetCheckState(DeclCheckState::CheckedHeader); + // global generic param only allowed in global scope + auto program = decl->ParentDecl->As<ModuleDecl>(); + if (!program) + getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly); + // Now check all of the member declarations. + for (auto member : decl->Members) + { + checkDecl(member); + } + decl->SetCheckState(DeclCheckState::Checked); + } + void visitAssocTypeDecl(AssocTypeDecl* decl) { if (decl->IsChecked(DeclCheckState::Checked)) return; @@ -3703,6 +3718,19 @@ namespace Slang return true; } } + // if an inheritance decl is not found, try to find a GenericTypeConstraintDecl + for (auto genConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(aggTypeDeclRef)) + { + EnsureDecl(genConstraintDeclRef.getDecl()); + auto inheritedType = GetSup(genConstraintDeclRef); + TypeWitnessBreadcrumb breadcrumb; + breadcrumb.prev = inBreadcrumbs; + breadcrumb.declRef = genConstraintDeclRef; + if (doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb)) + { + return true; + } + } } else if( auto genericTypeParamDeclRef = declRef.As<GenericTypeParamDecl>() ) { @@ -6582,6 +6610,78 @@ namespace Slang // that we don't have to re-do this effort again later. entryPoint->decl = entryPointFuncDecl; + // Lookup generic parameter types in global scope + for (auto name : entryPoint->genericParameterTypeNames) + { + if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName)) + { + // If there doesn't appear to be any such declaration, then + // we need to diagnose it as an error, and then bail out. + sink->diagnose(translationUnitSyntax, Diagnostics::entryPointTypeParameterNotFound, name); + return; + } + RefPtr<Type> type; + if (auto aggType = firstDeclWithName->As<AggTypeDecl>()) + { + type = DeclRefType::Create(entryPoint->compileRequest->mSession, DeclRef<Decl>(aggType, nullptr)); + } + else if (auto typeDefDecl = firstDeclWithName->As<TypeDefDecl>()) + { + type = GetType(DeclRef<TypeDefDecl>(typeDefDecl, nullptr)); + } + else + { + sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name); + return; + } + entryPoint->genericParameterTypes.Add(type); + } + // check that user-provioded type arguments conforms to the generic type + // parameter declaration of this translation unit + + // collect global generic parameters from all imported modules + List<RefPtr<GlobalGenericParamDecl>> globalGenericParams; + // add current translation unit first + { + auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + // add imported modules + for (auto moduleDecl : entryPoint->compileRequest->loadedModulesList) + { + auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>(); + for (auto p : globalGenParams) + globalGenericParams.Add(p); + } + if (globalGenericParams.Count() != entryPoint->genericParameterTypes.Count()) + { + sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, globalGenericParams.Count(), + entryPoint->genericParameterTypes.Count()); + return; + } + // if number of entry-point type arguments matches parameters, try find + // SubtypeWitness for each argument + int index = 0; + for (auto & gParam : globalGenericParams) + { + for (auto constraint : gParam->getMembersOfType<GenericTypeConstraintDecl>()) + { + auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr)); + SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit); + auto witness = visitor.tryGetSubtypeWitness(entryPoint->genericParameterTypes[index], interfaceType); + if (!witness) + { + sink->diagnose(gParam, + Diagnostics::typeArgumentDoesNotConformToInterface, gParam->nameAndLoc.name, entryPoint->genericParameterTypes[index], + interfaceType); + } + entryPoint->genericParameterWitnesses.Add(witness); + } + index++; + } + if (sink->errorCount != 0) + return; // TODO: after all that work, we are now in a position to start // validating the declaration itself. E.g., we should check if // the declared input/output parameters have suitable semantics, diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 302b5704f..acbf51e2e 100644 --- a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -11,7 +11,7 @@ #include "parser.h" #include "preprocessor.h" #include "syntax-visitors.h" - +#include "type-layout.h" #include "reflection.h" #include "emit.h" @@ -160,7 +160,7 @@ namespace Slang entryPoint, targetReq->layout.Ptr(), CodeGenTarget::HLSL, - targetReq->target); + targetReq); } } @@ -207,7 +207,7 @@ namespace Slang entryPoint, targetReq->layout.Ptr(), CodeGenTarget::GLSL, - targetReq->target); + targetReq); } } diff --git a/source/slang/compiler.h b/source/slang/compiler.h index f42f36c1f..303be6624 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -100,6 +100,10 @@ namespace Slang // The name of the entry point function (e.g., `main`) Name* name; + + // The type names we want to substitute into the + // global generic type parameters + List<Name*> genericParameterTypeNames; // The profile that the entry point will be compiled for // (this is a combination of the target state, and also @@ -123,6 +127,11 @@ namespace Slang // it should not be assumed to be available in cases // where any errors were diagnosed. RefPtr<FuncDecl> decl; + + // The declaration of the global generic parameter types + // This will be filled in as part of semantic analysis. + List<RefPtr<Type>> genericParameterTypes; + List<RefPtr<Val>> genericParameterWitnesses; }; enum class PassThroughMode : SlangPassThrough @@ -319,7 +328,8 @@ namespace Slang int addEntryPoint( int translationUnitIndex, String const& name, - Profile profile); + Profile profile, + List<String> const & genericTypeNames); UInt addTarget( CodeGenTarget target); diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index 9c010d156..e24a535c5 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -126,6 +126,11 @@ END_SYNTAX_CLASS() SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl) END_SYNTAX_CLASS() +// A '__generic_param' declaration, which defines a generic +// entry-point parameter. Is a container of GenericTypeConstraintDecl +SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl) +END_SYNTAX_CLASS() + // A scope for local declarations (e.g., as part of a statement) SIMPLE_SYNTAX_CLASS(ScopeDecl, ContainerDecl) diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h index 7f27e43e8..24e8bc713 100644 --- a/source/slang/diagnostic-defs.h +++ b/source/slang/diagnostic-defs.h @@ -196,7 +196,7 @@ DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of // 303xx: interfaces and associated types DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.") - +DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'__generic_param' can only be defined global scope.") // TODO: need to assign numbers to all these extra diagnostics... DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')") @@ -244,11 +244,17 @@ DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches en DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'") DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function") +DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'") +DIAGNOSTIC(38005, Error, entryPointTypeSymbolNotAType, "entry-point type parameter '$0' must be declared as a type") + DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'") DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type") DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration") DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration") +DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.") +DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.") + // // 4xxxx - IL code generation. // @@ -264,7 +270,6 @@ DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value seman // // 5xxxx - Target code generation. // - DIAGNOSTIC(50020, Error, unknownStageType, "Unknown stage type '$0'.") DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.") DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.") diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 5b7a42ad7..614e8f474 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -3481,9 +3481,9 @@ struct EmitVisitor break; case LayoutResourceKind::RegisterSpace: + case LayoutResourceKind::GenericResource: // ignore break; - default: { Emit(": register("); @@ -6771,7 +6771,7 @@ EntryPointLayout* findEntryPointLayout( StructTypeLayout* getGlobalStructLayout( ProgramLayout* programLayout) { - auto globalScopeLayout = programLayout->globalScopeLayout; + auto globalScopeLayout = programLayout->globalScopeLayout->typeLayout; if( auto gs = globalScopeLayout.As<StructTypeLayout>() ) { return gs.Ptr(); @@ -6816,13 +6816,13 @@ String emitEntryPoint( EntryPointRequest* entryPoint, ProgramLayout* programLayout, CodeGenTarget target, - CodeGenTarget finalTarget) + TargetRequest* targetRequest) { auto translationUnit = entryPoint->getTranslationUnit(); SharedEmitContext sharedContext; sharedContext.target = target; - sharedContext.finalTarget = finalTarget; + sharedContext.finalTarget = targetRequest->target; sharedContext.entryPoint = entryPoint; if (entryPoint) @@ -6890,7 +6890,8 @@ String emitEntryPoint( auto lowered = specializeIRForEntryPoint( entryPoint, programLayout, - target); + target, + targetRequest); // If the user specified the flag that they want us to dump // IR, then do it here, for the target-specific, but diff --git a/source/slang/emit.h b/source/slang/emit.h index e17a84d5a..98845f9c6 100644 --- a/source/slang/emit.h +++ b/source/slang/emit.h @@ -26,8 +26,7 @@ namespace Slang // The target language to generate code in (e.g., HLSL/GLSL) CodeGenTarget target, - // The "final" target that we are being asked to compile for - // (e.g., SPIR-V, DXBC, ...). - CodeGenTarget finalTarget); + // The full target request + TargetRequest* targetRequest); } #endif diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 52acf6576..a91143a43 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -419,6 +419,7 @@ struct IRBuilder IRFunc* createFunc(); IRGlobalVar* createGlobalVar( IRType* valueType); + IRWitnessTable* createWitnessTable(Dictionary<DeclRef<Decl>, Decl*> & witnesses); IRWitnessTable* createWitnessTable(); IRWitnessTableEntry* createWitnessTableEntry( IRWitnessTable* witnessTable, @@ -565,7 +566,8 @@ struct IRBuilder IRModule* specializeIRForEntryPoint( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, - CodeGenTarget target); + CodeGenTarget target, + TargetRequest* targetReq); // Find suitable uses of the `specialize` instruction that // can be replaced with references to specialized functions. diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 9068e717b..bfc26643c 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3089,12 +3089,16 @@ namespace Slang // to the layout to use for it. Dictionary<String, VarLayout*> globalVarLayouts; + RefPtr<GlobalGenericParamSubstitution> subst; + // Override the "maybe clone" logic so that we always clone virtual IRValue* maybeCloneValue(IRValue* originalVal) override; // Override teh "maybe clone" logic so that we carefully // clone any IR proxy values inside substitutions virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) override; + + virtual RefPtr<Type> maybeCloneType(Type* originalType) override; }; @@ -3102,6 +3106,11 @@ namespace Slang IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc); IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar); + RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType) + { + return originalType->Substitute(subst).As<Type>(); + } + IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue) { switch (originalValue->op) @@ -3143,6 +3152,33 @@ namespace Slang case kIROp_decl_ref: { IRDeclRef* od = (IRDeclRef*)originalValue; + + // if the declRef is one of the __generic_param decl being substituted by subst + // return the substituted decl + if (subst) + { + if (od->declRef.getDecl() == subst->paramDecl) + return builder->getTypeVal(subst->actualType.As<Type>()); + else if (auto genConstraint = od->declRef.As<GenericTypeConstraintDecl>()) + { + // a decl-ref to GenericTypeConstraintDecl as a result of + // referencing a generic parameter type should be replaced with + // the actual witness table + if (genConstraint.getDecl()->ParentDecl == subst->paramDecl) + { + // find the witness table from subst + for (auto witness : subst->witnessTables) + { + if (witness.Key->EqualsVal(GetSup(genConstraint))) + { + auto proxyVal = witness.Value.As<IRProxyVal>(); + SLANG_ASSERT(proxyVal); + return proxyVal->inst; + } + } + } + } + } auto declRef = maybeCloneDeclRef(od->declRef); return builder->getDeclRefVal(declRef); } @@ -3150,7 +3186,9 @@ namespace Slang case kIROp_TypeType: { IRValue* od = (IRValue*)originalValue; - return builder->getTypeVal(od->type); + int ioDiff = 0; + auto newType = od->type->SubstituteImpl(subst, &ioDiff); + return builder->getTypeVal(newType.As<Type>()); } break; default: @@ -3207,7 +3245,9 @@ namespace Slang newSubst->outer = cloneSubstitutions(context, subst->outer); return newSubst; } - return nullptr; + else + SLANG_UNREACHABLE("unimplemented cloneSubstitution"); + UNREACHABLE_RETURN(nullptr); } DeclRef<Decl> IRSpecContext::maybeCloneDeclRef(DeclRef<Decl> const& declRef) @@ -3281,7 +3321,7 @@ namespace Slang IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar) { - auto clonedVar = context->builder->createGlobalVar(originalVar->getType()->getValueType()); + auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType())); registerClonedValue(context, clonedVar, originalVar); auto mangledName = originalVar->mangledName; @@ -3703,10 +3743,67 @@ namespace Slang } } + // implementation provided in parameter-binding.cpp + RefPtr<ProgramLayout> specializeProgramLayout( + TargetRequest * targetReq, + ProgramLayout* programLayout, + Substitutions * typeSubst); + + RefPtr<GlobalGenericParamSubstitution> createGlobalGenericParamSubstitution( + EntryPointRequest * entryPointRequest, + ProgramLayout * programLayout, + IRSpecContext* context, + IRModule* originalIRModule) + { + RefPtr<GlobalGenericParamSubstitution> globalParamSubst; + Substitutions * curTailSubst = nullptr; + for (auto param : programLayout->globalGenericParams) + { + auto paramSubst = new GlobalGenericParamSubstitution(); + if (!globalParamSubst) + globalParamSubst = paramSubst; + if (curTailSubst) + curTailSubst->outer = paramSubst; + curTailSubst = paramSubst; + paramSubst->paramDecl = param->decl; + SLANG_ASSERT((UInt)param->index < entryPointRequest->genericParameterTypes.Count()); + paramSubst->actualType = entryPointRequest->genericParameterTypes[param->index]; + // find witness tables + for (auto witness : entryPointRequest->genericParameterWitnesses) + { + if (auto subtypeWitness = witness.As<SubtypeWitness>()) + { + if (subtypeWitness->sub->EqualsVal(paramSubst->actualType)) + { + auto witnessTableName = getMangledNameForConformanceWitness(subtypeWitness->sub, subtypeWitness->sup); + auto globalVar = originalIRModule->getFirstGlobalValue(); + IRGlobalValue * table = nullptr; + while (globalVar) + { + if (globalVar->mangledName == witnessTableName) + { + table = globalVar; + break; + } + globalVar = globalVar->getNextValue(); + } + SLANG_ASSERT(table); + table = cloneWitnessTable(context, (IRWitnessTable*)(table)); + IRProxyVal * tableVal = new IRProxyVal(); + tableVal->inst = table; + paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal)); + } + } + } + } + return globalParamSubst; + } + IRModule* specializeIRForEntryPoint( EntryPointRequest* entryPointRequest, ProgramLayout* programLayout, - CodeGenTarget target) + CodeGenTarget target, + TargetRequest* targetReq) { auto compileRequest = entryPointRequest->compileRequest; auto session = compileRequest->mSession; @@ -3720,8 +3817,6 @@ namespace Slang return nullptr; } - auto entryPointLayout = findEntryPointLayout(programLayout, entryPointRequest); - // We now need to start cloning IR symbols from `originalIRModule` // into a fresh IR module for this entry point. Along the way we need to: // @@ -3746,11 +3841,21 @@ namespace Slang context->builder = &sharedContextStorage.builderStorage; context->target = target; + // Create the GlobalGenericParamSubstitution for substituting global generic types + // into user-provided type arguments + auto globalParamSubst = createGlobalGenericParamSubstitution(entryPointRequest, programLayout, context, originalIRModule); + + context->subst = globalParamSubst; + + // now specailize the program layout using the substitution + RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(targetReq, programLayout, globalParamSubst); + + auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest); // Next, we want to optimize lookup for layout infromation // associated with global declarations, so that we can // look things up based on the IR values (using mangled names) - auto globalStructLayout = getGlobalStructLayout(programLayout); + auto globalStructLayout = getGlobalStructLayout(newProgramLayout); for (auto globalVarLayout : globalStructLayout->fields) { String mangledName = getMangledName(globalVarLayout->varDecl); diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp index b01732362..86bef3f4d 100644 --- a/source/slang/lookup.cpp +++ b/source/slang/lookup.cpp @@ -410,9 +410,9 @@ void lookUpMemberImpl( if (auto declRefType = type->As<DeclRefType>()) { auto declRef = declRefType->declRef; - if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>()) + if (declRef.As<AssocTypeDecl>() || declRef.As<GlobalGenericParamDecl>()) { - for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef)) + for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(declRef.As<ContainerDecl>())) { // The super-type in the constraint (e.g., `Foo` in `T : Foo`) // will tell us a type we should use for lookup. @@ -488,5 +488,4 @@ LookupResult lookUpMember( return result; } - -} +}
\ No newline at end of file diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 326d25649..0f3e85805 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -3538,7 +3538,9 @@ static void lowerEntryPointToIR( // the entry point request. return; } - + // we need to lower all global type arguments as well + for (auto arg : entryPointRequest->genericParameterTypes) + lowerType(context, arg); auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl); } diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index b375fa80e..5a6603add 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -2870,6 +2870,13 @@ struct LoweringVisitor UNREACHABLE_RETURN(LoweredDecl()); } + LoweredDecl visitGlobalGenericParamDecl(GlobalGenericParamDecl * /*decl*/) + { + // not supported + SLANG_UNREACHABLE("visitGlobalGenericParamDecl in LowerVisitor"); + UNREACHABLE_RETURN(LoweredDecl()); + } + LoweredDecl visitTypeDefDecl(TypeDefDecl* decl) { if (shared->target == CodeGenTarget::GLSL) diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index fa015186b..836ed254f 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -667,6 +667,17 @@ static void collectGlobalScopeGLSLVaryingParameter( } // Collect a single declaration into our set of parameters +static void collectGlobalGenericParameter( + ParameterBindingContext* context, + RefPtr<GlobalGenericParamDecl> paramDecl) +{ + RefPtr<GenericParamLayout> layout = new GenericParamLayout(); + layout->decl = paramDecl; + layout->index = (int)context->shared->programLayout->globalGenericParams.Count(); + context->shared->programLayout->globalGenericParams.Add(layout); +} + +// Collect a single declaration into our set of parameters static void collectGlobalScopeParameter( ParameterBindingContext* context, RefPtr<VarDeclBase> varDecl) @@ -1037,7 +1048,13 @@ static void completeBindingsForParameter( continue; } - + else if (kind == LayoutResourceKind::GenericResource) + { + bindingInfo.space = 0; + bindingInfo.count = 0; + bindingInfo.index = 0; + continue; + } // For now we only auto-generate bindings in space zero // @@ -1065,6 +1082,11 @@ static void completeBindingsForParameter( bindingInfo.space = space; } + if (firstTypeLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) + { + + } + // At this point we should have explicit binding locations chosen for // all the relevant resource kinds, so we can apply these to the // declarations: @@ -1093,15 +1115,22 @@ static void collectGlobalScopeParameters( ModuleDecl* program) { // First enumerate parameters at global scope - for( auto decl : program->Members ) + // We collect two things here: + // 1. A shader parameter, which is always a variable + // 2. A global entry-point generic parameter type (`__generic_param`), + // which is a GlobalGenericParamDecl + // We collect global generic type parameters in the first pass, + // So we can fill in the correct index into ordinary type layouts + // for generic types in the second pass. + for (auto decl : program->Members) { - // A shader parameter is always a variable, - // so skip declarations that aren't variables. - auto varDecl = decl.As<VarDeclBase>(); - if (!varDecl) - continue; - - collectGlobalScopeParameter(context, varDecl); + if (auto genParamDecl = decl.As<GlobalGenericParamDecl>()) + collectGlobalGenericParameter(context, genParamDecl); + } + for (auto decl : program->Members) + { + if (auto varDecl = decl.As<VarDeclBase>()) + collectGlobalScopeParameter(context, varDecl); } // Next, we need to enumerate the parameters of @@ -1665,7 +1694,8 @@ void generateParameterBindings( if (!layoutContext.rules) return; - RefPtr<ProgramLayout> programLayout = new ProgramLayout; + RefPtr<ProgramLayout> programLayout = new ProgramLayout(); + targetReq->layout = programLayout; // Create a context to hold shared state during the process // of generating parameter bindings @@ -1680,7 +1710,6 @@ void generateParameterBindings( context.shared = &sharedContext; context.translationUnit = nullptr; context.layoutContext = layoutContext; - // Walk through AST to discover all the parameters collectParameters(&context, compileReq); @@ -1707,6 +1736,7 @@ void generateParameterBindings( // If there are any global-scope uniforms, then we need to // allocate a constant-buffer binding for them here. ParameterBindingInfo globalConstantBufferBinding; + globalConstantBufferBinding.index = 0; if( anyGlobalUniforms ) { // TODO: this logic is only correct for D3D targets, where @@ -1838,8 +1868,191 @@ void generateParameterBindings( // We now have a bunch of layout information, which we should // record into a suitable object that represents the program - programLayout->globalScopeLayout = globalScopeLayout; - targetReq->layout = programLayout; + RefPtr<VarLayout> globalVarLayout = new VarLayout(); + globalVarLayout->typeLayout = globalScopeLayout; + if (anyGlobalUniforms) + { + auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer); + cbInfo->space = 0; + cbInfo->index = globalConstantBufferBinding.index; + } + programLayout->globalScopeLayout = globalVarLayout; } +StructTypeLayout* getGlobalStructLayout( + ProgramLayout* programLayout); + +RefPtr<ProgramLayout> specializeProgramLayout( + TargetRequest * targetReq, + ProgramLayout* programLayout, + Substitutions * typeSubst) +{ + RefPtr<ProgramLayout> newProgramLayout; + newProgramLayout = new ProgramLayout(); + newProgramLayout->bindingForHackSampler = programLayout->bindingForHackSampler; + newProgramLayout->hackSamplerVar = programLayout->hackSamplerVar; + for (auto & entryPoint : programLayout->entryPoints) + { + RefPtr<EntryPointLayout> newEntryPoint = new EntryPointLayout(*entryPoint); + // TODO: for now just copy existing entry point layouts, but we eventually need to + // specialize these as well... + newProgramLayout->entryPoints.Add(newEntryPoint); + } + + List<RefPtr<TypeLayout>> paramTypeLayouts; + auto globalStructLayout = getGlobalStructLayout(programLayout); + SLANG_ASSERT(globalStructLayout); + RefPtr<StructTypeLayout> structLayout = new StructTypeLayout(); + RefPtr<TypeLayout> globalScopeLayout = structLayout; + structLayout->uniformAlignment = globalStructLayout->uniformAlignment; + + // Try to find rules based on the selected code-generation target + auto layoutContext = getInitialLayoutContextForTarget(targetReq); + + // If there was no target, or there are no rules for the target, + // then bail out here. + if (!layoutContext.rules) + return newProgramLayout; + + + // we need to initialize a layout context to mark used registers + SharedParameterBindingContext sharedContext; + sharedContext.compileRequest = targetReq->compileRequest; + sharedContext.defaultLayoutRules = layoutContext.getRulesFamily(); + sharedContext.programLayout = programLayout; + + // Create a sub-context to collect parameters that get + // declared into the global scope + ParameterBindingContext context; + context.shared = &sharedContext; + context.translationUnit = nullptr; + context.layoutContext = layoutContext; + + auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules(); + structLayout->rules = constantBufferRules; + + UniformLayoutInfo structLayoutInfo; + structLayoutInfo.alignment = globalStructLayout->uniformAlignment; + structLayoutInfo.size = 0; + bool anyUniforms = false; + Dictionary<RefPtr<VarLayout>, RefPtr<VarLayout>> varLayoutMapping; + for (auto & varLayout : globalStructLayout->fields) + { + // To recover layout context, we skip generic resources in the first pass + // If the var is a generic resource, its resourceInfos will be empty. + if (varLayout->resourceInfos.Count() == 0) + continue; + SLANG_ASSERT(varLayout->resourceInfos.Count() == varLayout->typeLayout->resourceInfos.Count()); + auto uniformInfo = varLayout->FindResourceInfo(LayoutResourceKind::Uniform); + auto tUniformInfo = varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform); + if (uniformInfo) + { + anyUniforms = true; + SLANG_ASSERT(tUniformInfo); + structLayoutInfo.size = Math::Max(structLayoutInfo.size, uniformInfo->index + tUniformInfo->count); + } + for (UInt i = 0; i < varLayout->resourceInfos.Count(); i++) + { + auto resInfo = varLayout->resourceInfos[i]; + auto tresInfo = varLayout->typeLayout->resourceInfos[i]; + SLANG_ASSERT(resInfo.kind == tresInfo.kind); + auto usedRangeSet = findUsedRangeSetForSpace(&context, resInfo.space); + markSpaceUsed(&context, resInfo.space); + usedRangeSet->usedResourceRanges[(int)resInfo.kind].Add( + nullptr, // we don't need to track parameter info here + resInfo.index, + resInfo.index + varLayout->typeLayout->resourceInfos[0].count); + } + structLayout->fields.Add(varLayout); + varLayoutMapping[varLayout] = varLayout; + } + auto originalGlobalCBufferInfo = programLayout->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); + VarLayout::ResourceInfo globalCBufferInfo; + globalCBufferInfo.kind = LayoutResourceKind::None; + globalCBufferInfo.space = 0; + globalCBufferInfo.index = 0; + if (originalGlobalCBufferInfo) + { + globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer; + globalCBufferInfo.space = originalGlobalCBufferInfo->space; + globalCBufferInfo.index = originalGlobalCBufferInfo->index; + } + // we have the context restored, can continue to layout the generic variables now + for (auto & varLayout : globalStructLayout->fields) + { + if (varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) + { + RefPtr<Type> newType = varLayout->typeLayout->type->Substitute(typeSubst).As<Type>(); + RefPtr<TypeLayout> newTypeLayout = CreateTypeLayout( + layoutContext.with(constantBufferRules), + newType); + auto layoutInfo = newTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform); + size_t uniformSize = layoutInfo ? layoutInfo->count : 0; + if (uniformSize) + { + if (globalCBufferInfo.kind == LayoutResourceKind::None) + { + // user defined a uniform via a global generic type argument + // but we have not reserved a binding for the global uniform buffer + UInt space = 0; + auto usedRangeSet = findUsedRangeSetForSpace(&context, space); + globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer; + globalCBufferInfo.index = + usedRangeSet->usedResourceRanges[ + (int)LayoutResourceKind::ConstantBuffer].Allocate(nullptr, 1); + globalCBufferInfo.space = space; + } + } + RefPtr<VarLayout> newVarLayout = new VarLayout(); + RefPtr<ParameterInfo> paramInfo = new ParameterInfo(); + newVarLayout->varDecl = varLayout->varDecl; + newVarLayout->typeLayout = newTypeLayout; + paramInfo->varLayouts.Add(newVarLayout); + completeBindingsForParameter(&context, paramInfo); + // update uniform layout + + if (uniformSize != 0) + { + // Make sure uniform fields get laid out properly... + UniformLayoutInfo fieldInfo( + uniformSize, + newTypeLayout->uniformAlignment); + size_t uniformOffset = layoutContext.getRulesFamily()->getConstantBufferRules()->AddStructField( + &structLayoutInfo, + fieldInfo); + newVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset; + anyUniforms = true; + } + structLayout->fields.Add(newVarLayout); + varLayoutMapping[varLayout] = newVarLayout; + } + } + for (auto mapping : globalStructLayout->mapVarToLayout) + { + RefPtr<VarLayout> updatedVarLayout = mapping.Value; + varLayoutMapping.TryGetValue(updatedVarLayout, updatedVarLayout); + structLayout->mapVarToLayout[mapping.Key] = updatedVarLayout; + } + + // If there are global-scope uniforms, then we need to wrap + // up a global constant buffer type layout to hold them + RefPtr<VarLayout> globalVarLayout = new VarLayout(); + if (anyUniforms) + { + auto globalConstantBufferLayout = createParameterGroupTypeLayout( + layoutContext, + nullptr, + constantBufferRules, + constantBufferRules->GetObjectLayout(ShaderParameterKind::ConstantBuffer), + structLayout); + + globalScopeLayout = globalConstantBufferLayout; + auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer); + *cbInfo = globalCBufferInfo; + } + globalVarLayout->typeLayout = globalScopeLayout; + programLayout->globalScopeLayout = globalVarLayout; + newProgramLayout->globalScopeLayout = globalVarLayout; + return newProgramLayout; +} } diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 42c763099..0a4360e3f 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -95,9 +95,9 @@ namespace Slang RefPtr<StructDecl> ParseStruct(); RefPtr<ClassDecl> ParseClass(); RefPtr<Stmt> ParseStatement(); - RefPtr<Stmt> ParseBlockStatement(); - RefPtr<DeclStmt> ParseVarDeclrStatement(Modifiers modifiers); - RefPtr<IfStmt> ParseIfStatement(); + RefPtr<Stmt> parseBlockStatement(); + RefPtr<DeclStmt> parseVarDeclrStatement(Modifiers modifiers); + RefPtr<IfStmt> parseIfStatement(); RefPtr<ForStmt> ParseForStatement(); RefPtr<WhileStmt> ParseWhileStatement(); RefPtr<DoWhileStmt> ParseDoWhileStatement(); @@ -1034,7 +1034,7 @@ namespace Slang } else { - decl->Body = parser->ParseBlockStatement(); + decl->Body = parser->parseBlockStatement(); } parser->PopScope(); @@ -2172,41 +2172,55 @@ namespace Slang } } - RefPtr<RefObject> ParseAssocType(Parser * parser, void *) + void parseOptionalGenericConstraints(Parser * parser, ContainerDecl* decl) { - RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl(); - - auto nameToken = parser->ReadToken(TokenType::Identifier); - assocTypeDecl->nameAndLoc = NameLoc(nameToken); - assocTypeDecl->loc = nameToken.loc; if (AdvanceIf(parser, TokenType::Colon)) { - while (!parser->tokenReader.IsAtEnd()) + do { - auto paramConstraint = new GenericTypeConstraintDecl(); + RefPtr<GenericTypeConstraintDecl> paramConstraint = new GenericTypeConstraintDecl(); parser->FillPosition(paramConstraint); - auto paramType = DeclRefType::Create( + RefPtr<DeclRefType> paramType = DeclRefType::Create( parser->getSession(), - DeclRef<Decl>(assocTypeDecl, nullptr)); + DeclRef<Decl>(decl, nullptr)); - auto paramTypeExpr = new SharedTypeExpr(); - paramTypeExpr->loc = assocTypeDecl->loc; + RefPtr<SharedTypeExpr> paramTypeExpr = new SharedTypeExpr(); + paramTypeExpr->loc = decl->loc; paramTypeExpr->base.type = paramType; paramTypeExpr->type = QualType(getTypeType(paramType)); paramConstraint->sub = TypeExp(paramTypeExpr); paramConstraint->sup = parser->ParseTypeExp(); - AddMember(assocTypeDecl, paramConstraint); - if (!AdvanceIf(parser, TokenType::Comma)) - break; - } + AddMember(decl, paramConstraint); + } while (AdvanceIf(parser, TokenType::Comma)); } + } + + RefPtr<RefObject> parseAssocType(Parser * parser, void *) + { + RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl(); + + auto nameToken = parser->ReadToken(TokenType::Identifier); + assocTypeDecl->nameAndLoc = NameLoc(nameToken); + assocTypeDecl->loc = nameToken.loc; + parseOptionalGenericConstraints(parser, assocTypeDecl); parser->ReadToken(TokenType::Semicolon); return assocTypeDecl; } + RefPtr<RefObject> parseGlobalGenericParamDecl(Parser * parser, void *) + { + RefPtr<GlobalGenericParamDecl> genParamDecl = new GlobalGenericParamDecl(); + auto nameToken = parser->ReadToken(TokenType::Identifier); + genParamDecl->nameAndLoc = NameLoc(nameToken); + genParamDecl->loc = nameToken.loc; + parseOptionalGenericConstraints(parser, genParamDecl); + parser->ReadToken(TokenType::Semicolon); + return genParamDecl; + } + static RefPtr<RefObject> parseInterfaceDecl(Parser* parser, void* /*userData*/) { RefPtr<InterfaceDecl> decl = new InterfaceDecl(); @@ -2220,7 +2234,7 @@ namespace Slang return decl; } - static RefPtr<RefObject> ParseConstructorDecl(Parser* parser, void* /*userData*/) + static RefPtr<RefObject> parseConstructorDecl(Parser* parser, void* /*userData*/) { RefPtr<ConstructorDecl> decl = new ConstructorDecl(); parser->FillPosition(decl.Ptr()); @@ -2243,7 +2257,7 @@ namespace Slang } else { - decl->Body = parser->ParseBlockStatement(); + decl->Body = parser->parseBlockStatement(); } return decl; } @@ -2271,7 +2285,7 @@ namespace Slang if( parser->tokenReader.PeekTokenType() == TokenType::LBrace ) { - decl->Body = parser->ParseBlockStatement(); + decl->Body = parser->parseBlockStatement(); } else { @@ -2664,7 +2678,7 @@ namespace Slang parser->ReadToken(TokenType::LParent); stmt->condition = parser->ParseExpression(); parser->ReadToken(TokenType::RParent); - stmt->body = parser->ParseBlockStatement(); + stmt->body = parser->parseBlockStatement(); return stmt; } @@ -2788,11 +2802,11 @@ namespace Slang RefPtr<Stmt> statement; if (LookAheadToken(TokenType::LBrace)) - statement = ParseBlockStatement(); + statement = parseBlockStatement(); else if (peekTypeName(this)) - statement = ParseVarDeclrStatement(modifiers); + statement = parseVarDeclrStatement(modifiers); else if (LookAheadToken("if")) - statement = ParseIfStatement(); + statement = parseIfStatement(); else if (LookAheadToken("for")) statement = ParseForStatement(); else if (LookAheadToken("while")) @@ -2852,7 +2866,7 @@ namespace Slang // Note: the declaration will consume any modifiers // that had been in place on the statement. tokenReader.mCursor = startPos; - statement = ParseVarDeclrStatement(modifiers); + statement = parseVarDeclrStatement(modifiers); return statement; } @@ -2885,7 +2899,7 @@ namespace Slang return statement; } - RefPtr<Stmt> Parser::ParseBlockStatement() + RefPtr<Stmt> Parser::parseBlockStatement() { // If we are being asked not to check things *and* we haven't // seen any `import` declarations yet, then we can safely assume @@ -2983,7 +2997,7 @@ namespace Slang return blockStatement; } - RefPtr<DeclStmt> Parser::ParseVarDeclrStatement( + RefPtr<DeclStmt> Parser::parseVarDeclrStatement( Modifiers modifiers) { RefPtr<DeclStmt>varDeclrStatement = new DeclStmt(); @@ -2994,7 +3008,7 @@ namespace Slang return varDeclrStatement; } - RefPtr<IfStmt> Parser::ParseIfStatement() + RefPtr<IfStmt> Parser::parseIfStatement() { RefPtr<IfStmt> ifStatement = new IfStmt(); FillPosition(ifStatement.Ptr()); @@ -3045,7 +3059,7 @@ namespace Slang ReadToken(TokenType::LParent); if (peekTypeName(this)) { - stmt->InitialStatement = ParseVarDeclrStatement(Modifiers()); + stmt->InitialStatement = parseVarDeclrStatement(Modifiers()); } else { @@ -3107,7 +3121,7 @@ namespace Slang return breakStatement; } - RefPtr<ContinueStmt> Parser::ParseContinueStatement() + RefPtr<ContinueStmt> Parser::ParseContinueStatement() { RefPtr<ContinueStmt> continueStatement = new ContinueStmt(); FillPosition(continueStatement.Ptr()); @@ -4128,17 +4142,18 @@ namespace Slang // Add syntax for declaration keywords #define DECL(KEYWORD, CALLBACK) \ addBuiltinSyntax<Decl>(session, scope, #KEYWORD, &CALLBACK) - DECL(typedef, ParseTypeDef); - DECL(associatedtype,ParseAssocType); - DECL(cbuffer, parseHLSLCBufferDecl); - DECL(tbuffer, parseHLSLTBufferDecl); - DECL(__generic, ParseGenericDecl); - DECL(__extension, ParseExtensionDecl); - DECL(__init, ParseConstructorDecl); - DECL(__subscript, ParseSubscriptDecl); - DECL(interface, parseInterfaceDecl); - DECL(syntax, parseSyntaxDecl); - DECL(__import, parseImportDecl); + DECL(typedef, ParseTypeDef); + DECL(associatedtype, parseAssocType); + DECL(__generic_param, parseGlobalGenericParamDecl); + DECL(cbuffer, parseHLSLCBufferDecl); + DECL(tbuffer, parseHLSLTBufferDecl); + DECL(__generic, ParseGenericDecl); + DECL(__extension, ParseExtensionDecl); + DECL(__init, parseConstructorDecl); + DECL(__subscript, ParseSubscriptDecl); + DECL(interface, parseInterfaceDecl); + DECL(syntax, parseSyntaxDecl); + DECL(__import, parseImportDecl); #undef DECL diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index 9fc032c76..14199f126 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -886,3 +886,22 @@ SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangRefl return convert(program->entryPoints[(int) index].Ptr()); } + +SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if (!program) return 0; + auto cb = program->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); + if (!cb) return 0; + return cb->index; +} + +SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* inProgram) +{ + auto program = convert(inProgram); + if (!program) return 0; + auto structLayout = getGlobalStructLayout(program); + auto uniform = structLayout->FindResourceInfo(LayoutResourceKind::Uniform); + if (!uniform) return 0; + return uniform->count; +} diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 17f8ea96d..6a103fc2d 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -400,14 +400,16 @@ void CompileRequest::addTranslationUnitSourceFile( int CompileRequest::addEntryPoint( int translationUnitIndex, String const& name, - Profile entryPointProfile) + Profile entryPointProfile, + List<String> const & genericTypeNames) { RefPtr<EntryPointRequest> entryPoint = new EntryPointRequest(); entryPoint->compileRequest = this; entryPoint->name = getNamePool()->getName(name); entryPoint->profile = entryPointProfile; entryPoint->translationUnitIndex = translationUnitIndex; - + for (auto typeName : genericTypeNames) + entryPoint->genericParameterTypeNames.Add(getNamePool()->getName(typeName)); auto translationUnit = translationUnits[translationUnitIndex].Ptr(); translationUnit->entryPoints.Add(entryPoint); @@ -909,7 +911,31 @@ SLANG_API int spAddEntryPoint( return req->addEntryPoint( translationUnitIndex, name, - Slang::Profile(Slang::Profile::RawVal(profile))); + Slang::Profile(Slang::Profile::RawVal(profile)), + Slang::List<Slang::String>()); +} + +SLANG_API int spAddEntryPointEx( + SlangCompileRequest* request, + int translationUnitIndex, + char const* name, + SlangProfileID profile, + int genericParamTypeNameCount, + char const ** genericParamTypeNames) +{ + if (!request) return -1; + auto req = REQ(request); + if (!name) return -1; + if (translationUnitIndex < 0) return -1; + if (Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1; + Slang::List<Slang::String> typeNames; + for (int i = 0; i < genericParamTypeNameCount; i++) + typeNames.Add(genericParamTypeNames[i]); + return req->addEntryPoint( + translationUnitIndex, + name, + Slang::Profile(Slang::Profile::RawVal(profile)), + typeNames); } diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h index 3c7e8c5ae..fdb2694a9 100644 --- a/source/slang/syntax-base-defs.h +++ b/source/slang/syntax-base-defs.h @@ -197,6 +197,38 @@ SYNTAX_CLASS(ThisTypeSubstitution, Substitutions) ) END_SYNTAX_CLASS() +SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions) + // the __generic_param decl to be substituted + DECL_FIELD(GlobalGenericParamDecl*, paramDecl) + // the actual type to substitute in + SYNTAX_FIELD(RefPtr<Val>, actualType) + +RAW( + // Apply a set of substitutions to the bindings in this substitution + virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override; + + // Check if these are equivalent substitutiosn to another set + virtual bool Equals(Substitutions* subst) override; + virtual bool operator == (const Substitutions & subst) override + { + return Equals(const_cast<Substitutions*>(&subst)); + } + virtual int GetHashCode() const override + { + int rs = actualType->GetHashCode(); + for (auto && v : witnessTables) + { + rs = combineHash(rs, v.Key->GetHashCode()); + rs = combineHash(rs, v.Value->GetHashCode()); + } + return rs; + } + typedef List<KeyValuePair<RefPtr<Type>, RefPtr<Val>>> WitnessTableLookupTable; +) + // The witness tables for each interface this actual type implements + SYNTAX_FIELD(WitnessTableLookupTable, witnessTables) +END_SYNTAX_CLASS() + ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase) END_SYNTAX_CLASS() diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index e5fc8dfa3..fa9c88051 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -93,6 +93,7 @@ ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode); ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode); ABSTRACT_SYNTAX_CLASS(GenericSubstitution, Substitutions); ABSTRACT_SYNTAX_CLASS(ThisTypeSubstitution, Substitutions); +ABSTRACT_SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions); #include "expr-defs.h" #include "decl-defs.h" @@ -488,6 +489,20 @@ void Type::accept(IValVisitor* visitor, void* extra) } } } + else if (auto globalGenParam = dynamic_cast<GlobalGenericParamDecl*>(declRef.getDecl())) + { + // search for a substitution that might apply to us + for (auto s = subst; s; s = s->outer.Ptr()) + { + if (auto genericSubst = dynamic_cast<GlobalGenericParamSubstitution*>(s)) + { + if (genericSubst->paramDecl == globalGenParam) + { + return genericSubst->actualType; + } + } + } + } int diff = 0; DeclRef<Decl> substDeclRef = declRef.SubstituteImpl(subst, &diff); @@ -1208,6 +1223,35 @@ void Type::accept(IValVisitor* visitor, void* extra) return false; } + RefPtr<Substitutions> GlobalGenericParamSubstitution::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/) + { + // we will never replace values for this type of substitution + return this; + } + + bool GlobalGenericParamSubstitution::Equals(Substitutions* subst) + { + if (!subst) + return false; + if (auto genSubst = dynamic_cast<GlobalGenericParamSubstitution*>(subst)) + { + if (paramDecl != genSubst->paramDecl) + return false; + if (!actualType->EqualsVal(genSubst->actualType)) + return false; + if (witnessTables.Count() != genSubst->witnessTables.Count()) + return false; + for (UInt i = 0; i < witnessTables.Count(); i++) + { + if (!witnessTables[i].Key->Equals(genSubst->witnessTables[i].Key)) + return false; + if (!witnessTables[i].Value->EqualsVal(genSubst->witnessTables[i].Value)) + return false; + } + return true; + } + return false; + } // DeclRefBase @@ -1564,6 +1608,24 @@ void Type::accept(IValVisitor* visitor, void* extra) return genericSubst->args[index + ordinaryParamCount]; } } + else if (auto globalGenParamSubst = dynamic_cast<GlobalGenericParamSubstitution*>(s)) + { + // we have a GlobalGenericParamSubstitution, this substitution will provide + // a concrete IRWitnessTable for a generic global variable + auto supType = GetSup(genConstraintDecl); + + // check if the substitution is really about this global generic type parameter + if (globalGenParamSubst->paramDecl != genConstraintDecl.getDecl()->ParentDecl) + continue; + + // find witness table for the required interface + for (auto witness : globalGenParamSubst->witnessTables) + if (witness.Key->EqualsVal(supType)) + { + (*ioDiff)++; + return witness.Value; + } + } } } RefPtr<DeclaredSubtypeWitness> rs = new DeclaredSubtypeWitness(); diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 46beca2d9..b4d550ef5 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -1073,7 +1073,7 @@ namespace Slang { return declRef.Substitute(declRef.getDecl()->base.type); } - + inline RefPtr<Type> GetType(DeclRef<TypeDefDecl> const& declRef) { return declRef.Substitute(declRef.getDecl()->type.Ptr()); diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index 8fa790dd8..30b2ee01a 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -1222,6 +1222,11 @@ SimpleLayoutInfo GetLayoutImpl( return GetLayoutImpl(subContext, type, outTypeLayout, SimpleLayoutInfo()); } +int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl) +{ + return (int)genericParameters.FindFirst([=](RefPtr<GenericParamLayout> & x) {return x->decl.Ptr() == decl; }); +} + SimpleLayoutInfo GetLayoutImpl( TypeLayoutContext const& context, Type* type, @@ -1599,6 +1604,25 @@ SimpleLayoutInfo GetLayoutImpl( return info; } + else if (auto globalGenParam = declRef.As<GlobalGenericParamDecl>()) + { + SimpleLayoutInfo info; + info.alignment = 0; + info.size = 0; + info.kind = LayoutResourceKind::GenericResource; + if (outTypeLayout) + { + auto genParamTypeLayout = new GenericParamTypeLayout(); + // we should have already populated ProgramLayout::genericEntryPointParams list at this point, + // so we can find the index of this generic param decl in the list + genParamTypeLayout->type = type; + genParamTypeLayout->paramIndex = findGenericParam(context.targetReq->layout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl()); + genParamTypeLayout->rules = rules; + genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count++; + *outTypeLayout = genParamTypeLayout; + } + return info; + } } else if (auto errorType = type->As<ErrorType>()) { @@ -1667,4 +1691,12 @@ RefPtr<TypeLayout> CreateTypeLayout( return CreateTypeLayout(context, type, SimpleLayoutInfo()); } +RefPtr<GlobalGenericParamDecl> GenericParamTypeLayout::getGlobalGenericParamDecl() +{ + auto declRefType = type->AsDeclRefType(); + SLANG_ASSERT(declRefType); + auto rsDeclRef = declRefType->declRef.As<GlobalGenericParamDecl>(); + return rsDeclRef.getDecl(); +} + } // namespace Slang diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index 363b01486..4ce6dc355 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -220,7 +220,7 @@ typedef unsigned int VarLayoutFlags; enum VarLayoutFlag : VarLayoutFlags { IsRedeclaration = 1 << 0, ///< This is a redeclaration of some shader parameter - HasSemantic = 1 << 1, + HasSemantic = 1 << 1 }; // A reified layout for a particular variable, field, etc. @@ -358,6 +358,13 @@ public: Dictionary<Decl*, RefPtr<VarLayout>> mapVarToLayout; }; +class GenericParamTypeLayout : public TypeLayout +{ +public: + RefPtr<GlobalGenericParamDecl> getGlobalGenericParamDecl(); + int paramIndex = 0; +}; + // Layout information for a single shader entry point // within a program // @@ -386,6 +393,13 @@ public: unsigned flags = 0; }; +class GenericParamLayout : public Layout +{ +public: + RefPtr<GlobalGenericParamDecl> decl; + int index; +}; + // Layout information for the global scope of a program class ProgramLayout : public Layout { @@ -403,13 +417,15 @@ public: // (since a constant buffer will have to be allocated // to store them). // - RefPtr<TypeLayout> globalScopeLayout; + RefPtr<VarLayout> globalScopeLayout; // We catalog the requested entry points here, // and any entry-point-specific parameter data // will (eventually) belong there... List<RefPtr<EntryPointLayout>> entryPoints; + List<RefPtr<GenericParamLayout>> globalGenericParams; + // HACK: binding to use when we have to create // a dummy sampler just to appease glslang int bindingForHackSampler = 0; |
