diff options
| author | Yong He <yonghe@outlook.com> | 2018-01-21 10:48:31 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2018-01-21 10:48:31 -0800 |
| commit | 4044a1d3a0605198465a7eb6e0e3c1f8b1a3c298 (patch) | |
| tree | 62927d4d2722b36c8e7eb4060e741b9032686835 | |
| parent | 2079b941bc5849b6ab33774fb90cefe9c2d624cb (diff) | |
| parent | f681a1505c98995683a7fbae7ce208dc5e444b9b (diff) | |
Merge pull request #372 from csyonghe/master
Allow type expression as type argument, fix global param enum order
23 files changed, 376 insertions, 126 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index dfc09c485..52558ee15 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -6849,6 +6849,23 @@ namespace Slang return (!decl->primaryDecl) || (decl == decl->primaryDecl); } + RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp) + { + RefPtr<Type> type; + DiagnosticSink nSink; + nSink.sourceManager = tu->compileRequest->sourceManager; + SemanticsVisitor visitor( + &nSink, + tu->compileRequest, + tu); + auto typeOut = visitor.CheckProperType(typeExp); + if (!nSink.errorCount) + { + type = typeOut.type; + } + return type; + } + void validateEntryPoint( EntryPointRequest* entryPoint) { @@ -6944,26 +6961,25 @@ namespace Slang entryPoint->decl = entryPointFuncDecl; // Lookup generic parameter types in global scope + List<RefPtr<Scope>> scopesToTry; + scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope); + for (auto & module : entryPoint->compileRequest->loadedModulesList) + scopesToTry.Add(module->moduleDecl->scope); for (auto name : entryPoint->genericParameterTypeNames) - { - firstDeclWithName = entryPoint->compileRequest->lookupGlobalDecl(name); - if (!firstDeclWithName) - { - // If there doesn't appear to be any such declaration, then - // we need to diagnose it as an error, and then bail out. - sink->diagnose(translationUnitSyntax, Diagnostics::entryPointTypeParameterNotFound, name); - return; - } + { + // parse type name RefPtr<Type> type; - if (auto aggType = firstDeclWithName->As<AggTypeDecl>()) - { - type = DeclRefType::Create(entryPoint->compileRequest->mSession, DeclRef<Decl>(aggType, nullptr)); - } - else if (auto typeDefDecl = firstDeclWithName->As<TypeDefDecl>()) + for (auto & s : scopesToTry) { - type = GetType(DeclRef<TypeDefDecl>(typeDefDecl, nullptr)); + RefPtr<Expr> typeExpr = entryPoint->compileRequest->parseTypeString(entryPoint->getTranslationUnit(), + name, s); + type = checkProperType(translationUnit, TypeExp(typeExpr)); + if (type) + { + break; + } } - else + if (!type) { sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name); return; diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 1fca4751c..960e67ffe 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -104,7 +104,7 @@ namespace Slang // The type names we want to substitute into the // global generic type parameters - List<Name*> genericParameterTypeNames; + List<String> genericParameterTypeNames; // The profile that the entry point will be compiled for // (this is a combination of the target state, and also @@ -318,6 +318,10 @@ namespace Slang ~CompileRequest(); + RefPtr<Expr> parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope); + + Type* getTypeFromString(String typeStr); + void parseTranslationUnit( TranslationUnitRequest* translationUnit); diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index fb35e327a..8e1985e3f 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -196,7 +196,9 @@ SIMPLE_SYNTAX_CLASS(Variable, VarDeclBase); // A "module" of code (essentiately, a single translation unit) // that provides a scope for some number of declarations. -SIMPLE_SYNTAX_CLASS(ModuleDecl, ContainerDecl) +SYNTAX_CLASS(ModuleDecl, ContainerDecl) + FIELD(RefPtr<Scope>, scope) +END_SYNTAX_CLASS() SYNTAX_CLASS(ImportDecl, Decl) // The name of the module we are trying to import diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 53f02cc56..18216de81 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -7500,7 +7500,7 @@ String emitEntryPoint( // none of our target supports generics, or interfaces, // so we need to specialize those away. // - specializeGenerics(irModule); + specializeGenerics(irModule, sharedContext.target); // Debugging code for IR transformations... #if 0 diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 23e948b3a..dedc906d0 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -641,7 +641,8 @@ void specializeIRForEntryPoint( // Find suitable uses of the `specialize` instruction that // can be replaced with references to specialized functions. void specializeGenerics( - IRModule* module); + IRModule* module, + CodeGenTarget target); // diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 994ac82ff..7318bff4c 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -7,6 +7,13 @@ namespace Slang { + struct IRSpecContext; + + IRGlobalValue* cloneGlobalValueWithMangledName( + IRSpecContext* context, + String const& mangledName, + IRGlobalValue* originalVal); + static const IROpInfo kIROpInfos[] = { @@ -3065,6 +3072,9 @@ namespace Slang struct IRSharedSpecContext { + // The code-generation target in use + CodeGenTarget target; + // The specialized module we are building IRModule* module; @@ -3091,6 +3101,10 @@ namespace Slang struct IRSpecContextBase { + // A map from the mangled name of a global variable + // to the layout to use for it. + Dictionary<String, VarLayout*> globalVarLayouts; + IRSharedSpecContext* shared; IRSharedSpecContext* getShared() { return shared; } @@ -3224,13 +3238,6 @@ namespace Slang struct IRSpecContext : IRSpecContextBase { - // The code-generation target in use - CodeGenTarget target; - - // A map from the mangled name of a global variable - // to the layout to use for it. - Dictionary<String, VarLayout*> globalVarLayouts; - // Override the "maybe clone" logic so that we always clone virtual IRValue* maybeCloneValue(IRValue* originalVal) override; @@ -3434,18 +3441,31 @@ namespace Slang return newDeclRef; } - IRValue* cloneValue( IRSpecContextBase* context, IRValue* originalValue) { IRValue* clonedValue = nullptr; if (context->getClonedValues().TryGetValue(originalValue, clonedValue)) + { return clonedValue; + } return context->maybeCloneValue(originalValue); } + IRValue* maybeCloneValueWithMangledName( + IRSpecContextBase* context, + IRGlobalValue* originalValue) + { + for (auto gv = context->shared->module->firstGlobalValue; gv; gv = gv->nextGlobalValue) + { + if (gv->mangledName == originalValue->mangledName) + return gv; + } + return cloneValue(context, originalValue); + } + void cloneInst( IRSpecContextBase* context, IRBuilder* builder, @@ -3468,18 +3488,23 @@ namespace Slang context->maybeCloneType(originalInst->type), 0, nullptr, argCount, nullptr); - builder->addInst(clonedInst); registerClonedValue(context, clonedInst, originalInst); - - cloneDecorations(context, clonedInst, originalInst); - + auto oldBuilder = context->builder; + context->builder = builder; for (UInt aa = 0; aa < argCount; ++aa) { IRValue* originalArg = originalInst->getArg(aa); - IRValue* clonedArg = cloneValue(context, originalArg); - + IRValue* clonedArg; + if (originalArg->op == kIROp_witness_table) + clonedArg = cloneGlobalValueWithMangledName((IRSpecContext*)context, + ((IRGlobalValue*)originalArg)->mangledName, (IRGlobalValue*)originalArg); + else + clonedArg = cloneValue(context, originalArg); clonedInst->getArgs()[aa].init(clonedInst, clonedArg); } + builder->addInst(clonedInst); + context->builder = oldBuilder; + cloneDecorations(context, clonedInst, originalInst); } break; @@ -3524,12 +3549,15 @@ namespace Slang IRSpecContextBase* context, IRWitnessTable* originalTable, IROriginalValuesForClone const& originalValues, - IRWitnessTable* dstTable = nullptr) + IRWitnessTable* dstTable = nullptr, + bool registerValue = true) { auto clonedTable = dstTable ? dstTable : context->builder->createWitnessTable(); - registerClonedValue(context, clonedTable, originalValues); + if (registerValue) + registerClonedValue(context, clonedTable, originalValues); auto mangledName = originalTable->mangledName; + clonedTable->mangledName = mangledName; clonedTable->genericDecl = originalTable->genericDecl; clonedTable->subTypeDeclRef = originalTable->subTypeDeclRef; @@ -3539,8 +3567,11 @@ namespace Slang // Clone the entries in the witness table as well for( auto originalEntry : originalTable->entries ) { - auto clonedKey = context->maybeCloneValue(originalEntry->requirementKey.usedValue); - auto clonedVal = context->maybeCloneValue(originalEntry->satisfyingVal.usedValue); + auto clonedKey = cloneValue(context, originalEntry->requirementKey.usedValue); + + // if a global val with the mangled name already exists, don't clone again + auto clonedVal = maybeCloneValueWithMangledName(context, (IRGlobalValue*)(originalEntry->satisfyingVal.usedValue)); + /*auto clonedEntry = */context->builder->createWitnessTableEntry( clonedTable, clonedKey, @@ -3555,7 +3586,7 @@ namespace Slang IRWitnessTable* originalTable, IRWitnessTable* dstTable = nullptr) { - return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable); + return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable, false); } void cloneGlobalValueWithCodeCommon( @@ -3690,14 +3721,6 @@ namespace Slang // and their instructions. cloneFunctionCommon(context, clonedFunc, originalFunc); - // for now, clone all unreferenced witness tables - for (auto gv = context->getOriginalModule()->getFirstGlobalValue(); - gv; gv = gv->getNextValue()) - { - if (gv->op == kIROp_witness_table) - cloneGlobalValue(context, (IRWitnessTable*)gv); - } - // We need to attach the layout information for // the entry point to this declaration, so that // we can use it to inform downstream code emit. @@ -3746,7 +3769,7 @@ namespace Slang // TODO: We shouldn't be using strings for this. String getTargetName(IRSpecContext* context) { - switch( context->target ) + switch( context->shared->target ) { case CodeGenTarget::HLSL: return "hlsl"; @@ -4035,7 +4058,8 @@ namespace Slang IRSharedSpecContext* sharedContext, Session* session, IRModule* module, - IRModule* originalModule) + IRModule* originalModule, + CodeGenTarget target) { SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage; @@ -4053,7 +4077,7 @@ namespace Slang sharedContext->module = module; sharedContext->originalModule = originalModule; - + sharedContext->target = target; // We will populate a map with all of the IR values // that use the same mangled name, to make lookup easier // in other steps. @@ -4110,7 +4134,9 @@ namespace Slang sharedContext, compileRequest->mSession, nullptr, - originalIRModule); + originalIRModule, + target); + state->irModule = sharedContext->module; // We also need to attach the IR definitions for symbols from @@ -4123,7 +4149,6 @@ namespace Slang auto context = state->getContext(); context->shared = sharedContext; context->builder = &sharedContext->builderStorage; - context->target = target; // Create the GlobalGenericParamSubstitution for substituting global generic types // into user-provided type arguments @@ -4146,6 +4171,12 @@ namespace Slang context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout); } + // for now, clone all unreferenced witness tables + for (auto sym :context->getSymbols()) + { + if (sym.Value->irGlobalValue->op == kIROp_witness_table) + cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue); + } return state; } @@ -4263,7 +4294,31 @@ namespace Slang return symbol->irGlobalValue; } else - return nullptr; + { + // we don't have the required witness table yet, + // try to emit a specialize instruction to get one + auto subDeclRef = subtypeWitness->sub->AsDeclRefType(); + auto subDeclRefGen = DeclRef<Decl>(subDeclRef->declRef.decl, + createDefaultSubstitutions(context->builder->getSession(), subDeclRef->declRef.decl)); + + String genericName = getMangledNameForConformanceWitness( + subDeclRefGen, + subtypeWitness->sup); + if (context->getSymbols().TryGetValue(genericName, symbol)) + { + auto specInst = context->builder->emitSpecializeInst(subtypeWitness->sup, symbol->irGlobalValue, subDeclRef->declRef); + return specInst; + } + else + { + SLANG_UNEXPECTED("witness table not exist"); + UNREACHABLE_RETURN(nullptr); + } + } + } + else if (auto intVal = dynamic_cast<ConstantIntVal*>(val)) + { + return context->builder->getIntValue(context->shared->originalModule->session->getBuiltinType(BaseType::Int), intVal->value); } else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val)) { @@ -4321,10 +4376,34 @@ namespace Slang return getIRValue(context, subst->args[argIndex]); } + else if (auto valDeclRef = declRef.As<GenericValueParamDecl>()) + { + // We have a constraint, but we need to find its index in the + // argument list of the substitutions. + UInt argIdx = 0; + bool found = false; + for (auto cd : genericDecl->Members) + { + if (cd.Ptr() == valDeclRef.getDecl()) + { + found = true; + break; + } + if (cd.As<GenericTypeParamDecl>()) + argIdx++; + else if (cd.As<GenericValueParamDecl>()) + argIdx++; + } + assert(found); + + assert(argIdx < subst->args.Count()); + + return getIRValue(context, subst->args[argIdx]); + } else { - SLANG_UNEXPECTED("unhandled case"); - return nullptr; + SLANG_UNEXPECTED("unimplemented"); + UNREACHABLE_RETURN(nullptr); } } @@ -4342,12 +4421,13 @@ namespace Slang // of the generic we are specializing, and in that case // we nee to translate it over to the equiavalent of // the `Val` we have been given. - if(declRef.getDecl()->ParentDecl == genSubst->genericDecl) + if(declRef.getDecl()->ParentDecl == genSubst->genericDecl && + (declRef.As<GenericTypeParamDecl>() || declRef.As<GenericValueParamDecl>()|| + declRef.As<GenericTypeConstraintDecl>())) { if (auto substVal = getSubstValue(this, declRef)) return substVal; } - int diff = 0; auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff); if(!diff) @@ -4455,7 +4535,6 @@ namespace Slang // has already been made. To do that we will need to // compute the mangled name of the specialized function, // so that we can look for existing declarations. - String specMangledName; String specializedMangledName = getMangledNameForConformanceWitness(specDeclRef.Substitute(originalTable->subTypeDeclRef), specDeclRef.Substitute(originalTable->supTypeDeclRef)); @@ -4466,13 +4545,15 @@ namespace Slang // avoid it by building a dictionary ahead of time, // as is being done for the `IRSpecContext` used above. // We can probalby use the same basic context, actually. - auto module = originalTable->parentModule; - for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) + if (!dstTable) { - if (gv->mangledName == specMangledName) - return (IRWitnessTable*)gv; + auto module = sharedContext->module; + for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) + { + if (gv->mangledName == specializedMangledName) + return (IRWitnessTable*)gv; + } } - RefPtr<GenericSubstitution> newSubst = cloneSubstitutionsForSpecialization( sharedContext, specDeclRef.substitutions.genericSubstitutions, @@ -4483,13 +4564,12 @@ namespace Slang context.builder = &sharedContext->builderStorage; context.subst = specDeclRef.substitutions; context.subst.genericSubstitutions = newSubst; - // TODO: other initialization is needed here... auto specTable = cloneWitnessTableWithoutRegistering(&context, originalTable, dstTable); // Set up the clone to recognize that it is no longer generic - specTable->mangledName = specMangledName; + specTable->mangledName = specializedMangledName; specTable->genericDecl = nullptr; // Specialization of witness tables should trigger cascading specializations @@ -4499,8 +4579,9 @@ namespace Slang if (entry->satisfyingVal.usedValue->op == kIROp_Func) { IRFunc* func = (IRFunc*)entry->satisfyingVal.usedValue; - if (func->getGenericDecl()) - entry->satisfyingVal.set(getSpecializedFunc(sharedContext, func, specDeclRef)); + auto specFunc = getSpecializedFunc(sharedContext, func, specDeclRef); + entry->satisfyingVal.set(specFunc); + insertGlobalValueSymbol(sharedContext, specFunc); } } @@ -4526,13 +4607,16 @@ namespace Slang specMangledName = getMangledName(specDeclRef); else specMangledName = mangleSpecializedFuncName(genericFunc->mangledName, specDeclRef.substitutions); - + RefPtr<IRSpecSymbol> symb; + if (sharedContext->symbols.TryGetValue(specMangledName, symb)) + { + return (IRFunc*)(symb->irGlobalValue); + } // TODO: This is a terrible linear search, and we should // avoid it by building a dictionary ahead of time, // as is being done for the `IRSpecContext` used above. // We can probalby use the same basic context, actually. - auto module = genericFunc->parentModule; - for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) + for (auto gv = sharedContext->module->getFirstGlobalValue(); gv; gv = gv->getNextValue()) { if (gv->mangledName == specMangledName) return (IRFunc*) gv; @@ -4639,7 +4723,8 @@ namespace Slang // are known, and specialize the callee based on those // known values. void specializeGenerics( - IRModule* module) + IRModule* module, + CodeGenTarget target) { IRSharedSpecContext sharedContextStorage; auto sharedContext = &sharedContextStorage; @@ -4648,7 +4733,8 @@ namespace Slang sharedContext, module->session, module, - module); + module, + target); // Our goal here is to find `specialize` instructions that // can be replaced with references to a suitably sepcialized @@ -4895,11 +4981,10 @@ namespace Slang table = findWitnessTableByName(genericWitnessTableName); SLANG_ASSERT(table); WitnessTableSpecializationWorkItem workItem; - workItem.srcTable = (IRWitnessTable*)table; + workItem.srcTable = (IRWitnessTable*)cloneGlobalValue(context, (IRWitnessTable*)(table)); workItem.dstTable = context->builder->createWitnessTable(); workItem.dstTable->mangledName = getMangledNameForConformanceWitness(subDeclRefType->declRef, subtypeWitness->sup); workItem.specDeclRef = subDeclRefType->declRef; - registerClonedValue(context, workItem.dstTable, workItem.srcTable); witnessTablesToSpecailize.Add(workItem); table = workItem.dstTable; } diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp index 211685aa2..d1cef4dac 100644 --- a/source/slang/legalize-types.cpp +++ b/source/slang/legalize-types.cpp @@ -916,7 +916,7 @@ LegalType legalizeType( } legalType = builder.getResult(); - context->mapDeclRefToLegalType.Add(declRef, legalType); + context->mapDeclRefToLegalType.AddIfNotExists(declRef, legalType); return legalType; } diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 5e7e05a23..5d710725a 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -1021,6 +1021,7 @@ RefPtr<IRFuncType> getFuncType( return funcType; } +SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst); // struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo> @@ -1080,8 +1081,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower // TODO: actually test what module the type is coming from. lowerDecl(context, type->declRef); - - return LoweredTypeInfo(type); } @@ -3006,6 +3005,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> return globalVal; } + LoweredValInfo visitGenericValueParamDecl(GenericValueParamDecl* decl) + { + return LoweredValInfo::simple(context->irBuilder->getDeclRefVal(DeclRefBase(decl))); + } + LoweredValInfo visitVarDeclBase(VarDeclBase* decl) { // Detect global (or effectively global) variables @@ -3733,7 +3737,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> if (auto innerFuncDecl = genDecl->inner->As<FuncDecl>()) return lowerFuncDecl(innerFuncDecl); else if (auto innerStructDecl = genDecl->inner->As<StructDecl>()) + { + visitAggTypeDecl(innerStructDecl); return LoweredValInfo(); + } SLANG_RELEASE_ASSERT(false); UNREACHABLE_RETURN(LoweredValInfo()); } @@ -3910,6 +3917,32 @@ RefPtr<GenericSubstitution> lowerGenericSubstitutions( return result; } +RefPtr<GlobalGenericParamSubstitution> lowerGlobalGenericSubstitutions( + IRGenContext* context, + GlobalGenericParamSubstitution* genSubst) +{ + if (!genSubst) + return nullptr; + RefPtr<GlobalGenericParamSubstitution> result; + RefPtr<GlobalGenericParamSubstitution> newSubst = new GlobalGenericParamSubstitution(); + newSubst->actualType = lowerSubstitutionArg(context, genSubst->actualType); + newSubst->paramDecl = genSubst->paramDecl; + for (auto & tbl : genSubst->witnessTables) + { + auto ntbl = tbl; + ntbl.Value = lowerSubstitutionArg(context, tbl.Value); + newSubst->witnessTables.Add(ntbl); + } + result = newSubst; + if (genSubst->outer) + { + result->outer = lowerGlobalGenericSubstitutions( + context, + genSubst->outer); + } + return result; +} + RefPtr<ThisTypeSubstitution> lowerThisTypeSubstitution( IRGenContext* context, ThisTypeSubstitution* thisSubst) @@ -3926,7 +3959,7 @@ SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst) SubstitutionSet rs; rs.genericSubstitutions = lowerGenericSubstitutions(context, subst.genericSubstitutions); rs.thisTypeSubstitution = lowerThisTypeSubstitution(context, subst.thisTypeSubstitution); - rs.globalGenParamSubstitutions = subst.globalGenParamSubstitutions; + rs.globalGenParamSubstitutions = lowerGlobalGenericSubstitutions(context, subst.globalGenParamSubstitutions); return rs; } @@ -3973,10 +4006,10 @@ LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context, // need to walk through those and replace things in // cases where the `Val`s used for substitution should // lower to something other than their original form. - auto lowedNewSubst = lowerGenericSubstitutions(context, newSubst); - DeclRef<Decl> newDeclRef = DeclRef<Decl>(declRef.decl, - SubstitutionSet(lowedNewSubst, declRef.substitutions.thisTypeSubstitution, - declRef.substitutions.globalGenParamSubstitutions)); + SubstitutionSet oldSubst = declRef.substitutions; + oldSubst.genericSubstitutions = newSubst; + auto lowedNewSubst = lowerSubstitutions(context, oldSubst); + DeclRef<Decl> newDeclRef = DeclRef<Decl>(declRef.decl, lowedNewSubst); RefPtr<Type> type; if (auto declType = val->getType()) @@ -4014,9 +4047,9 @@ static void lowerEntryPointToIR( return; } // we need to lower all global type arguments as well + auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl); for (auto arg : entryPointRequest->genericParameterTypes) lowerType(context, arg); - auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl); } #if 0 diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index e5ea1d531..e1c5c1aca 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -2109,14 +2109,15 @@ RefPtr<ProgramLayout> specializeProgramLayout( auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules(); structLayout->rules = constantBufferRules; - + structLayout->fields.SetSize(globalStructLayout->fields.Count()); UniformLayoutInfo structLayoutInfo; structLayoutInfo.alignment = globalStructLayout->uniformAlignment; structLayoutInfo.size = 0; bool anyUniforms = false; Dictionary<RefPtr<VarLayout>, RefPtr<VarLayout>> varLayoutMapping; - for (auto & varLayout : globalStructLayout->fields) + for (uint32_t varId = 0; varId < globalStructLayout->fields.Count(); varId++) { + auto &varLayout = globalStructLayout->fields[varId]; // To recover layout context, we skip generic resources in the first pass if (varLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) continue; @@ -2141,7 +2142,7 @@ RefPtr<ProgramLayout> specializeProgramLayout( resInfo.index, resInfo.index + tresInfo.count); } - structLayout->fields.Add(varLayout); + structLayout->fields[varId] = varLayout; varLayoutMapping[varLayout] = varLayout; } auto originalGlobalCBufferInfo = programLayout->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer); @@ -2156,8 +2157,9 @@ RefPtr<ProgramLayout> specializeProgramLayout( globalCBufferInfo.index = originalGlobalCBufferInfo->index; } // we have the context restored, can continue to layout the generic variables now - for (auto & varLayout : globalStructLayout->fields) + for (uint32_t varId = 0; varId < globalStructLayout->fields.Count(); varId++) { + auto &varLayout = globalStructLayout->fields[varId]; if (varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::GenericResource)) { RefPtr<Type> newType = varLayout->typeLayout->type->Substitute(typeSubst).As<Type>(); @@ -2202,7 +2204,7 @@ RefPtr<ProgramLayout> specializeProgramLayout( newVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset; anyUniforms = true; } - structLayout->fields.Add(newVarLayout); + structLayout->fields[varId] = newVarLayout; varLayoutMapping[varLayout] = newVarLayout; } } diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 7e36b0e71..531606f8d 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2704,6 +2704,7 @@ namespace Slang PushScope(program); program->loc = tokenReader.PeekLoc(); + program->scope = currentScope; ParseDeclBody(this, program, TokenType::EndOfFile); PopScope(); @@ -3960,6 +3961,17 @@ namespace Slang return parsePrefixExpr(this); } + RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit, + TokenSpan const& tokens, + DiagnosticSink* sink, + RefPtr<Scope> const& outerScope) + { + Parser parser(tokens, sink, outerScope); + parser.translationUnit = translationUnit; + parser.currentScope = outerScope; + return parser.ParseType(); + } + // Parse a source file into an existing translation unit void parseSourceFile( TranslationUnitRequest* translationUnit, @@ -3971,6 +3983,7 @@ namespace Slang parser.translationUnit = translationUnit; + return parser.parseSourceFile(translationUnit->SyntaxNode.Ptr()); } diff --git a/source/slang/parser.h b/source/slang/parser.h index 60fe4b3ae..785b6e345 100644 --- a/source/slang/parser.h +++ b/source/slang/parser.h @@ -14,6 +14,11 @@ namespace Slang DiagnosticSink* sink, RefPtr<Scope> const& outerScope); + RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit, + TokenSpan const& tokens, + DiagnosticSink* sink, + RefPtr<Scope> const& outerScope); + RefPtr<ModuleDecl> populateBaseLanguageModule( Session* session, RefPtr<Scope> scope); diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp index c9de75d6e..b0be58274 100644 --- a/source/slang/reflection.cpp +++ b/source/slang/reflection.cpp @@ -433,20 +433,8 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re auto context = convert(reflection); auto compileRequest = context->targetRequest->compileRequest; - RefPtr<Type> result; - if (compileRequest->types.TryGetValue(name, result)) - return (SlangReflectionType*)result.Ptr(); - - auto nameObj = compileRequest->getNamePool()->getName(name); - Decl* resultDecl = compileRequest->lookupGlobalDecl(nameObj); - if (resultDecl) - { - RefPtr<DeclRefType> declRefType = new DeclRefType(); - declRefType->declRef.decl = resultDecl; - compileRequest->types[name] = declRefType; - return (SlangReflectionType*)declRefType.Ptr(); - } - return nullptr; + RefPtr<Type> result = compileRequest->getTypeFromString(name); + return (SlangReflectionType*)result.Ptr(); } SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout( diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 2ebf024e3..4c9ecf8a8 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -125,6 +125,50 @@ CompileRequest::CompileRequest(Session* session) CompileRequest::~CompileRequest() {} + +RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope) +{ + Slang::SourceFile srcFile; + srcFile.content = typeStr; + DiagnosticSink sink; + sink.sourceManager = sourceManager; + auto tokens = preprocessSource( + &srcFile, + &sink, + nullptr, + Dictionary<String,String>(), + translationUnit); + return parseTypeFromSourceFile(translationUnit, tokens, &sink, scope); +} + +RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp); +Type* CompileRequest::getTypeFromString(String typeStr) +{ + RefPtr<Type> type; + if (types.TryGetValue(typeStr, type)) + return type; + auto translationUnit = translationUnits.First(); + List<RefPtr<Scope>> scopesToTry; + for (auto tu : translationUnits) + scopesToTry.Add(tu->SyntaxNode->scope); + for (auto & module : loadedModulesList) + scopesToTry.Add(module->moduleDecl->scope); + // parse type name + for (auto & s : scopesToTry) + { + RefPtr<Expr> typeExpr = parseTypeString(translationUnit, + typeStr, s); + type = checkProperType(translationUnit, TypeExp(typeExpr)); + if (type) + break; + } + if (type) + { + types[typeStr] = type; + } + return type.Ptr(); +} + void CompileRequest::parseTranslationUnit( TranslationUnitRequest* translationUnit) { @@ -429,7 +473,7 @@ int CompileRequest::addEntryPoint( entryPoint->profile = entryPointProfile; entryPoint->translationUnitIndex = translationUnitIndex; for (auto typeName : genericTypeNames) - entryPoint->genericParameterTypeNames.Add(getNamePool()->getName(typeName)); + entryPoint->genericParameterTypeNames.Add(typeName); auto translationUnit = translationUnits[translationUnitIndex].Ptr(); translationUnit->entryPoints.Add(entryPoint); diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 552f1dc56..ab4a5f94c 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -355,19 +355,21 @@ void Type::accept(IValVisitor* visitor, void* extra) auto arrType = type->AsArrayType(); if (!arrType) return false; - return (ArrayLength == arrType->ArrayLength && baseType->Equals(arrType->baseType.Ptr())); + return (ArrayLength->EqualsVal(arrType->ArrayLength) && baseType->Equals(arrType->baseType.Ptr())); } RefPtr<Val> ArrayExpressionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff) { int diff = 0; auto elementType = baseType->SubstituteImpl(subst, &diff).As<Type>(); + auto arrlen = ArrayLength->SubstituteImpl(subst, &diff).As<IntVal>(); + SLANG_ASSERT(arrlen); if (diff) { *ioDiff = 1; auto rsType = getArrayType( elementType, - ArrayLength); + arrlen); return rsType; } return this; diff --git a/tests/compute/array-param.slang b/tests/compute/array-param.slang new file mode 100644 index 000000000..78ca52518 --- /dev/null +++ b/tests/compute/array-param.slang @@ -0,0 +1,19 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +RWStructuredBuffer<int> outputBuffer; +void writeArray(inout float3 a[4]) +{ + a[0] = float3(1, 1, 1); + a[1] = float3(1, 1, 1); + a[2] = float3(1, 1, 1); + a[3] = float3(1, 1, 1); +} + +[numthreads(4, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + float3 b[4]; + writeArray(b); + outputBuffer[dispatchThreadID.x] = b[0].x; +}
\ No newline at end of file diff --git a/tests/compute/array-param.slang.expected.txt b/tests/compute/array-param.slang.expected.txt new file mode 100644 index 000000000..ef529012e --- /dev/null +++ b/tests/compute/array-param.slang.expected.txt @@ -0,0 +1,4 @@ +1 +1 +1 +1
\ No newline at end of file diff --git a/tests/compute/global-type-param3.slang b/tests/compute/global-type-param-array.slang index 05793dce4..74e52d5d4 100644 --- a/tests/compute/global-type-param3.slang +++ b/tests/compute/global-type-param-array.slang @@ -1,23 +1,10 @@ -//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir -//TEST_INPUT: cbuffer(data=[1.0], stride=4):dxbinding(0),glbinding(0) +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT: cbuffer(data=[1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0], stride=4):dxbinding(0),glbinding(0) //TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out -//TEST_INPUT: type Impl +//TEST_INPUT: type Pair<Arr<Base,1>, Pair<Arr<Base,2> , Base> > RWStructuredBuffer<float> outputBuffer; - -interface IBase -{ - float compute(); -} - -struct Impl : IBase -{ - float base; // = 1.0 - float compute() - { - return 1.0; - } -}; +import globalTypeParamArrayShared; __generic_param TImpl : IBase; @@ -25,7 +12,7 @@ ParameterBlock<TImpl> impl; float doCompute<T:IBase>(T t) { - return t.compute(); + return t.compute(1.0); } [numthreads(1, 1, 1)] diff --git a/tests/compute/global-type-param-array.slang.expected.txt b/tests/compute/global-type-param-array.slang.expected.txt new file mode 100644 index 000000000..bdf6b77dc --- /dev/null +++ b/tests/compute/global-type-param-array.slang.expected.txt @@ -0,0 +1 @@ +40800000 diff --git a/tests/compute/global-type-param.slang b/tests/compute/global-type-param.slang index 301ef1021..03f5df329 100644 --- a/tests/compute/global-type-param.slang +++ b/tests/compute/global-type-param.slang @@ -1,6 +1,6 @@ //TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir //TEST_INPUT:ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out -//TEST_INPUT:type Impl +//TEST_INPUT:type Wrapper<Impl> RWStructuredBuffer<float> outputBuffer; @@ -9,6 +9,15 @@ interface IBase float compute(); } +struct Wrapper<T : IBase> : IBase +{ + T obj; + float compute() + { + return obj.compute(); + } +}; + struct Impl : IBase { float compute() diff --git a/tests/compute/global-type-param3.slang.expected.txt b/tests/compute/global-type-param3.slang.expected.txt deleted file mode 100644 index deb1c3630..000000000 --- a/tests/compute/global-type-param3.slang.expected.txt +++ /dev/null @@ -1 +0,0 @@ -3F800000 diff --git a/tests/compute/globalTypeParamArrayShared.slang b/tests/compute/globalTypeParamArrayShared.slang new file mode 100644 index 000000000..ee3caa372 --- /dev/null +++ b/tests/compute/globalTypeParamArrayShared.slang @@ -0,0 +1,32 @@ +//TEST_IGNORE_FILE: +interface IBase +{ + float compute<T>(T g); +} +struct Base:IBase +{ + float b; + float compute<T>(T g) { return b; } +}; + +struct Pair<T1:IBase, T2:IBase> : IBase +{ + T1 head; + T2 tail; + float compute<T>(T g) + { + return head.compute(g) + tail.compute(g); + } +}; + +struct Arr<T:IBase, let N:int> : IBase +{ + T base[N]; // = 1.0 + float compute<T>(T g) + { + float sum = 0.0; + for (int i = 0; i < N; i++) + sum += base[i].compute(g); + return sum; + } +}; diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp index 01328eabd..fcf25f376 100644 --- a/tools/render-test/shader-input-layout.cpp +++ b/tools/render-test/shader-input-layout.cpp @@ -20,7 +20,10 @@ namespace renderer_test if (parser.LookAhead("type")) { parser.ReadToken(); - globalTypeArguments.Add(parser.ReadWord()); + StringBuilder typeExp; + while (!parser.IsEnd()) + typeExp << parser.ReadToken().Content; + globalTypeArguments.Add(typeExp); } else { diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index cfbc24382..9263aa41b 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -100,6 +100,7 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler (int)rawTypeNames.Count(), rawTypeNames.Buffer()); int compileErr = spCompile(slangRequest); + spSetLineDirectiveMode(slangRequest, SLANG_LINE_DIRECTIVE_MODE_NONE); if (auto diagnostics = spGetDiagnosticOutput(slangRequest)) { fprintf(stderr, "%s", diagnostics); |
