summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2018-01-21 10:48:31 -0800
committerGitHub <noreply@github.com>2018-01-21 10:48:31 -0800
commit4044a1d3a0605198465a7eb6e0e3c1f8b1a3c298 (patch)
tree62927d4d2722b36c8e7eb4060e741b9032686835
parent2079b941bc5849b6ab33774fb90cefe9c2d624cb (diff)
parentf681a1505c98995683a7fbae7ce208dc5e444b9b (diff)
Merge pull request #372 from csyonghe/master
Allow type expression as type argument, fix global param enum order
-rw-r--r--source/slang/check.cpp48
-rw-r--r--source/slang/compiler.h6
-rw-r--r--source/slang/decl-defs.h4
-rw-r--r--source/slang/emit.cpp2
-rw-r--r--source/slang/ir-insts.h3
-rw-r--r--source/slang/ir.cpp193
-rw-r--r--source/slang/legalize-types.cpp2
-rw-r--r--source/slang/lower-to-ir.cpp49
-rw-r--r--source/slang/parameter-binding.cpp12
-rw-r--r--source/slang/parser.cpp13
-rw-r--r--source/slang/parser.h5
-rw-r--r--source/slang/reflection.cpp16
-rw-r--r--source/slang/slang.cpp46
-rw-r--r--source/slang/syntax.cpp6
-rw-r--r--tests/compute/array-param.slang19
-rw-r--r--tests/compute/array-param.slang.expected.txt4
-rw-r--r--tests/compute/global-type-param-array.slang (renamed from tests/compute/global-type-param3.slang)23
-rw-r--r--tests/compute/global-type-param-array.slang.expected.txt1
-rw-r--r--tests/compute/global-type-param.slang11
-rw-r--r--tests/compute/global-type-param3.slang.expected.txt1
-rw-r--r--tests/compute/globalTypeParamArrayShared.slang32
-rw-r--r--tools/render-test/shader-input-layout.cpp5
-rw-r--r--tools/render-test/slang-support.cpp1
23 files changed, 376 insertions, 126 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index dfc09c485..52558ee15 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -6849,6 +6849,23 @@ namespace Slang
return (!decl->primaryDecl) || (decl == decl->primaryDecl);
}
+ RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp)
+ {
+ RefPtr<Type> type;
+ DiagnosticSink nSink;
+ nSink.sourceManager = tu->compileRequest->sourceManager;
+ SemanticsVisitor visitor(
+ &nSink,
+ tu->compileRequest,
+ tu);
+ auto typeOut = visitor.CheckProperType(typeExp);
+ if (!nSink.errorCount)
+ {
+ type = typeOut.type;
+ }
+ return type;
+ }
+
void validateEntryPoint(
EntryPointRequest* entryPoint)
{
@@ -6944,26 +6961,25 @@ namespace Slang
entryPoint->decl = entryPointFuncDecl;
// Lookup generic parameter types in global scope
+ List<RefPtr<Scope>> scopesToTry;
+ scopesToTry.Add(entryPoint->getTranslationUnit()->SyntaxNode->scope);
+ for (auto & module : entryPoint->compileRequest->loadedModulesList)
+ scopesToTry.Add(module->moduleDecl->scope);
for (auto name : entryPoint->genericParameterTypeNames)
- {
- firstDeclWithName = entryPoint->compileRequest->lookupGlobalDecl(name);
- if (!firstDeclWithName)
- {
- // If there doesn't appear to be any such declaration, then
- // we need to diagnose it as an error, and then bail out.
- sink->diagnose(translationUnitSyntax, Diagnostics::entryPointTypeParameterNotFound, name);
- return;
- }
+ {
+ // parse type name
RefPtr<Type> type;
- if (auto aggType = firstDeclWithName->As<AggTypeDecl>())
- {
- type = DeclRefType::Create(entryPoint->compileRequest->mSession, DeclRef<Decl>(aggType, nullptr));
- }
- else if (auto typeDefDecl = firstDeclWithName->As<TypeDefDecl>())
+ for (auto & s : scopesToTry)
{
- type = GetType(DeclRef<TypeDefDecl>(typeDefDecl, nullptr));
+ RefPtr<Expr> typeExpr = entryPoint->compileRequest->parseTypeString(entryPoint->getTranslationUnit(),
+ name, s);
+ type = checkProperType(translationUnit, TypeExp(typeExpr));
+ if (type)
+ {
+ break;
+ }
}
- else
+ if (!type)
{
sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name);
return;
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 1fca4751c..960e67ffe 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -104,7 +104,7 @@ namespace Slang
// The type names we want to substitute into the
// global generic type parameters
- List<Name*> genericParameterTypeNames;
+ List<String> genericParameterTypeNames;
// The profile that the entry point will be compiled for
// (this is a combination of the target state, and also
@@ -318,6 +318,10 @@ namespace Slang
~CompileRequest();
+ RefPtr<Expr> parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope);
+
+ Type* getTypeFromString(String typeStr);
+
void parseTranslationUnit(
TranslationUnitRequest* translationUnit);
diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h
index fb35e327a..8e1985e3f 100644
--- a/source/slang/decl-defs.h
+++ b/source/slang/decl-defs.h
@@ -196,7 +196,9 @@ SIMPLE_SYNTAX_CLASS(Variable, VarDeclBase);
// A "module" of code (essentiately, a single translation unit)
// that provides a scope for some number of declarations.
-SIMPLE_SYNTAX_CLASS(ModuleDecl, ContainerDecl)
+SYNTAX_CLASS(ModuleDecl, ContainerDecl)
+ FIELD(RefPtr<Scope>, scope)
+END_SYNTAX_CLASS()
SYNTAX_CLASS(ImportDecl, Decl)
// The name of the module we are trying to import
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 53f02cc56..18216de81 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -7500,7 +7500,7 @@ String emitEntryPoint(
// none of our target supports generics, or interfaces,
// so we need to specialize those away.
//
- specializeGenerics(irModule);
+ specializeGenerics(irModule, sharedContext.target);
// Debugging code for IR transformations...
#if 0
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 23e948b3a..dedc906d0 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -641,7 +641,8 @@ void specializeIRForEntryPoint(
// Find suitable uses of the `specialize` instruction that
// can be replaced with references to specialized functions.
void specializeGenerics(
- IRModule* module);
+ IRModule* module,
+ CodeGenTarget target);
//
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 994ac82ff..7318bff4c 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -7,6 +7,13 @@
namespace Slang
{
+ struct IRSpecContext;
+
+ IRGlobalValue* cloneGlobalValueWithMangledName(
+ IRSpecContext* context,
+ String const& mangledName,
+ IRGlobalValue* originalVal);
+
static const IROpInfo kIROpInfos[] =
{
@@ -3065,6 +3072,9 @@ namespace Slang
struct IRSharedSpecContext
{
+ // The code-generation target in use
+ CodeGenTarget target;
+
// The specialized module we are building
IRModule* module;
@@ -3091,6 +3101,10 @@ namespace Slang
struct IRSpecContextBase
{
+ // A map from the mangled name of a global variable
+ // to the layout to use for it.
+ Dictionary<String, VarLayout*> globalVarLayouts;
+
IRSharedSpecContext* shared;
IRSharedSpecContext* getShared() { return shared; }
@@ -3224,13 +3238,6 @@ namespace Slang
struct IRSpecContext : IRSpecContextBase
{
- // The code-generation target in use
- CodeGenTarget target;
-
- // A map from the mangled name of a global variable
- // to the layout to use for it.
- Dictionary<String, VarLayout*> globalVarLayouts;
-
// Override the "maybe clone" logic so that we always clone
virtual IRValue* maybeCloneValue(IRValue* originalVal) override;
@@ -3434,18 +3441,31 @@ namespace Slang
return newDeclRef;
}
-
IRValue* cloneValue(
IRSpecContextBase* context,
IRValue* originalValue)
{
IRValue* clonedValue = nullptr;
if (context->getClonedValues().TryGetValue(originalValue, clonedValue))
+ {
return clonedValue;
+ }
return context->maybeCloneValue(originalValue);
}
+ IRValue* maybeCloneValueWithMangledName(
+ IRSpecContextBase* context,
+ IRGlobalValue* originalValue)
+ {
+ for (auto gv = context->shared->module->firstGlobalValue; gv; gv = gv->nextGlobalValue)
+ {
+ if (gv->mangledName == originalValue->mangledName)
+ return gv;
+ }
+ return cloneValue(context, originalValue);
+ }
+
void cloneInst(
IRSpecContextBase* context,
IRBuilder* builder,
@@ -3468,18 +3488,23 @@ namespace Slang
context->maybeCloneType(originalInst->type),
0, nullptr,
argCount, nullptr);
- builder->addInst(clonedInst);
registerClonedValue(context, clonedInst, originalInst);
-
- cloneDecorations(context, clonedInst, originalInst);
-
+ auto oldBuilder = context->builder;
+ context->builder = builder;
for (UInt aa = 0; aa < argCount; ++aa)
{
IRValue* originalArg = originalInst->getArg(aa);
- IRValue* clonedArg = cloneValue(context, originalArg);
-
+ IRValue* clonedArg;
+ if (originalArg->op == kIROp_witness_table)
+ clonedArg = cloneGlobalValueWithMangledName((IRSpecContext*)context,
+ ((IRGlobalValue*)originalArg)->mangledName, (IRGlobalValue*)originalArg);
+ else
+ clonedArg = cloneValue(context, originalArg);
clonedInst->getArgs()[aa].init(clonedInst, clonedArg);
}
+ builder->addInst(clonedInst);
+ context->builder = oldBuilder;
+ cloneDecorations(context, clonedInst, originalInst);
}
break;
@@ -3524,12 +3549,15 @@ namespace Slang
IRSpecContextBase* context,
IRWitnessTable* originalTable,
IROriginalValuesForClone const& originalValues,
- IRWitnessTable* dstTable = nullptr)
+ IRWitnessTable* dstTable = nullptr,
+ bool registerValue = true)
{
auto clonedTable = dstTable ? dstTable : context->builder->createWitnessTable();
- registerClonedValue(context, clonedTable, originalValues);
+ if (registerValue)
+ registerClonedValue(context, clonedTable, originalValues);
auto mangledName = originalTable->mangledName;
+
clonedTable->mangledName = mangledName;
clonedTable->genericDecl = originalTable->genericDecl;
clonedTable->subTypeDeclRef = originalTable->subTypeDeclRef;
@@ -3539,8 +3567,11 @@ namespace Slang
// Clone the entries in the witness table as well
for( auto originalEntry : originalTable->entries )
{
- auto clonedKey = context->maybeCloneValue(originalEntry->requirementKey.usedValue);
- auto clonedVal = context->maybeCloneValue(originalEntry->satisfyingVal.usedValue);
+ auto clonedKey = cloneValue(context, originalEntry->requirementKey.usedValue);
+
+ // if a global val with the mangled name already exists, don't clone again
+ auto clonedVal = maybeCloneValueWithMangledName(context, (IRGlobalValue*)(originalEntry->satisfyingVal.usedValue));
+
/*auto clonedEntry = */context->builder->createWitnessTableEntry(
clonedTable,
clonedKey,
@@ -3555,7 +3586,7 @@ namespace Slang
IRWitnessTable* originalTable,
IRWitnessTable* dstTable = nullptr)
{
- return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable);
+ return cloneWitnessTableImpl(context, originalTable, IROriginalValuesForClone(), dstTable, false);
}
void cloneGlobalValueWithCodeCommon(
@@ -3690,14 +3721,6 @@ namespace Slang
// and their instructions.
cloneFunctionCommon(context, clonedFunc, originalFunc);
- // for now, clone all unreferenced witness tables
- for (auto gv = context->getOriginalModule()->getFirstGlobalValue();
- gv; gv = gv->getNextValue())
- {
- if (gv->op == kIROp_witness_table)
- cloneGlobalValue(context, (IRWitnessTable*)gv);
- }
-
// We need to attach the layout information for
// the entry point to this declaration, so that
// we can use it to inform downstream code emit.
@@ -3746,7 +3769,7 @@ namespace Slang
// TODO: We shouldn't be using strings for this.
String getTargetName(IRSpecContext* context)
{
- switch( context->target )
+ switch( context->shared->target )
{
case CodeGenTarget::HLSL:
return "hlsl";
@@ -4035,7 +4058,8 @@ namespace Slang
IRSharedSpecContext* sharedContext,
Session* session,
IRModule* module,
- IRModule* originalModule)
+ IRModule* originalModule,
+ CodeGenTarget target)
{
SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
@@ -4053,7 +4077,7 @@ namespace Slang
sharedContext->module = module;
sharedContext->originalModule = originalModule;
-
+ sharedContext->target = target;
// We will populate a map with all of the IR values
// that use the same mangled name, to make lookup easier
// in other steps.
@@ -4110,7 +4134,9 @@ namespace Slang
sharedContext,
compileRequest->mSession,
nullptr,
- originalIRModule);
+ originalIRModule,
+ target);
+
state->irModule = sharedContext->module;
// We also need to attach the IR definitions for symbols from
@@ -4123,7 +4149,6 @@ namespace Slang
auto context = state->getContext();
context->shared = sharedContext;
context->builder = &sharedContext->builderStorage;
- context->target = target;
// Create the GlobalGenericParamSubstitution for substituting global generic types
// into user-provided type arguments
@@ -4146,6 +4171,12 @@ namespace Slang
context->globalVarLayouts.AddIfNotExists(mangledName, globalVarLayout);
}
+ // for now, clone all unreferenced witness tables
+ for (auto sym :context->getSymbols())
+ {
+ if (sym.Value->irGlobalValue->op == kIROp_witness_table)
+ cloneGlobalValue(context, (IRWitnessTable*)sym.Value->irGlobalValue);
+ }
return state;
}
@@ -4263,7 +4294,31 @@ namespace Slang
return symbol->irGlobalValue;
}
else
- return nullptr;
+ {
+ // we don't have the required witness table yet,
+ // try to emit a specialize instruction to get one
+ auto subDeclRef = subtypeWitness->sub->AsDeclRefType();
+ auto subDeclRefGen = DeclRef<Decl>(subDeclRef->declRef.decl,
+ createDefaultSubstitutions(context->builder->getSession(), subDeclRef->declRef.decl));
+
+ String genericName = getMangledNameForConformanceWitness(
+ subDeclRefGen,
+ subtypeWitness->sup);
+ if (context->getSymbols().TryGetValue(genericName, symbol))
+ {
+ auto specInst = context->builder->emitSpecializeInst(subtypeWitness->sup, symbol->irGlobalValue, subDeclRef->declRef);
+ return specInst;
+ }
+ else
+ {
+ SLANG_UNEXPECTED("witness table not exist");
+ UNREACHABLE_RETURN(nullptr);
+ }
+ }
+ }
+ else if (auto intVal = dynamic_cast<ConstantIntVal*>(val))
+ {
+ return context->builder->getIntValue(context->shared->originalModule->session->getBuiltinType(BaseType::Int), intVal->value);
}
else if (auto proxyVal = dynamic_cast<IRProxyVal*>(val))
{
@@ -4321,10 +4376,34 @@ namespace Slang
return getIRValue(context, subst->args[argIndex]);
}
+ else if (auto valDeclRef = declRef.As<GenericValueParamDecl>())
+ {
+ // We have a constraint, but we need to find its index in the
+ // argument list of the substitutions.
+ UInt argIdx = 0;
+ bool found = false;
+ for (auto cd : genericDecl->Members)
+ {
+ if (cd.Ptr() == valDeclRef.getDecl())
+ {
+ found = true;
+ break;
+ }
+ if (cd.As<GenericTypeParamDecl>())
+ argIdx++;
+ else if (cd.As<GenericValueParamDecl>())
+ argIdx++;
+ }
+ assert(found);
+
+ assert(argIdx < subst->args.Count());
+
+ return getIRValue(context, subst->args[argIdx]);
+ }
else
{
- SLANG_UNEXPECTED("unhandled case");
- return nullptr;
+ SLANG_UNEXPECTED("unimplemented");
+ UNREACHABLE_RETURN(nullptr);
}
}
@@ -4342,12 +4421,13 @@ namespace Slang
// of the generic we are specializing, and in that case
// we nee to translate it over to the equiavalent of
// the `Val` we have been given.
- if(declRef.getDecl()->ParentDecl == genSubst->genericDecl)
+ if(declRef.getDecl()->ParentDecl == genSubst->genericDecl &&
+ (declRef.As<GenericTypeParamDecl>() || declRef.As<GenericValueParamDecl>()||
+ declRef.As<GenericTypeConstraintDecl>()))
{
if (auto substVal = getSubstValue(this, declRef))
return substVal;
}
-
int diff = 0;
auto substDeclRef = declRefVal->declRef.SubstituteImpl(subst, &diff);
if(!diff)
@@ -4455,7 +4535,6 @@ namespace Slang
// has already been made. To do that we will need to
// compute the mangled name of the specialized function,
// so that we can look for existing declarations.
- String specMangledName;
String specializedMangledName = getMangledNameForConformanceWitness(specDeclRef.Substitute(originalTable->subTypeDeclRef),
specDeclRef.Substitute(originalTable->supTypeDeclRef));
@@ -4466,13 +4545,15 @@ namespace Slang
// avoid it by building a dictionary ahead of time,
// as is being done for the `IRSpecContext` used above.
// We can probalby use the same basic context, actually.
- auto module = originalTable->parentModule;
- for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
+ if (!dstTable)
{
- if (gv->mangledName == specMangledName)
- return (IRWitnessTable*)gv;
+ auto module = sharedContext->module;
+ for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
+ {
+ if (gv->mangledName == specializedMangledName)
+ return (IRWitnessTable*)gv;
+ }
}
-
RefPtr<GenericSubstitution> newSubst = cloneSubstitutionsForSpecialization(
sharedContext,
specDeclRef.substitutions.genericSubstitutions,
@@ -4483,13 +4564,12 @@ namespace Slang
context.builder = &sharedContext->builderStorage;
context.subst = specDeclRef.substitutions;
context.subst.genericSubstitutions = newSubst;
-
// TODO: other initialization is needed here...
auto specTable = cloneWitnessTableWithoutRegistering(&context, originalTable, dstTable);
// Set up the clone to recognize that it is no longer generic
- specTable->mangledName = specMangledName;
+ specTable->mangledName = specializedMangledName;
specTable->genericDecl = nullptr;
// Specialization of witness tables should trigger cascading specializations
@@ -4499,8 +4579,9 @@ namespace Slang
if (entry->satisfyingVal.usedValue->op == kIROp_Func)
{
IRFunc* func = (IRFunc*)entry->satisfyingVal.usedValue;
- if (func->getGenericDecl())
- entry->satisfyingVal.set(getSpecializedFunc(sharedContext, func, specDeclRef));
+ auto specFunc = getSpecializedFunc(sharedContext, func, specDeclRef);
+ entry->satisfyingVal.set(specFunc);
+ insertGlobalValueSymbol(sharedContext, specFunc);
}
}
@@ -4526,13 +4607,16 @@ namespace Slang
specMangledName = getMangledName(specDeclRef);
else
specMangledName = mangleSpecializedFuncName(genericFunc->mangledName, specDeclRef.substitutions);
-
+ RefPtr<IRSpecSymbol> symb;
+ if (sharedContext->symbols.TryGetValue(specMangledName, symb))
+ {
+ return (IRFunc*)(symb->irGlobalValue);
+ }
// TODO: This is a terrible linear search, and we should
// avoid it by building a dictionary ahead of time,
// as is being done for the `IRSpecContext` used above.
// We can probalby use the same basic context, actually.
- auto module = genericFunc->parentModule;
- for (auto gv = module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
+ for (auto gv = sharedContext->module->getFirstGlobalValue(); gv; gv = gv->getNextValue())
{
if (gv->mangledName == specMangledName)
return (IRFunc*) gv;
@@ -4639,7 +4723,8 @@ namespace Slang
// are known, and specialize the callee based on those
// known values.
void specializeGenerics(
- IRModule* module)
+ IRModule* module,
+ CodeGenTarget target)
{
IRSharedSpecContext sharedContextStorage;
auto sharedContext = &sharedContextStorage;
@@ -4648,7 +4733,8 @@ namespace Slang
sharedContext,
module->session,
module,
- module);
+ module,
+ target);
// Our goal here is to find `specialize` instructions that
// can be replaced with references to a suitably sepcialized
@@ -4895,11 +4981,10 @@ namespace Slang
table = findWitnessTableByName(genericWitnessTableName);
SLANG_ASSERT(table);
WitnessTableSpecializationWorkItem workItem;
- workItem.srcTable = (IRWitnessTable*)table;
+ workItem.srcTable = (IRWitnessTable*)cloneGlobalValue(context, (IRWitnessTable*)(table));
workItem.dstTable = context->builder->createWitnessTable();
workItem.dstTable->mangledName = getMangledNameForConformanceWitness(subDeclRefType->declRef, subtypeWitness->sup);
workItem.specDeclRef = subDeclRefType->declRef;
- registerClonedValue(context, workItem.dstTable, workItem.srcTable);
witnessTablesToSpecailize.Add(workItem);
table = workItem.dstTable;
}
diff --git a/source/slang/legalize-types.cpp b/source/slang/legalize-types.cpp
index 211685aa2..d1cef4dac 100644
--- a/source/slang/legalize-types.cpp
+++ b/source/slang/legalize-types.cpp
@@ -916,7 +916,7 @@ LegalType legalizeType(
}
legalType = builder.getResult();
- context->mapDeclRefToLegalType.Add(declRef, legalType);
+ context->mapDeclRefToLegalType.AddIfNotExists(declRef, legalType);
return legalType;
}
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 5e7e05a23..5d710725a 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -1021,6 +1021,7 @@ RefPtr<IRFuncType> getFuncType(
return funcType;
}
+SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst);
//
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo>
@@ -1080,8 +1081,6 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// TODO: actually test what module the type is coming from.
lowerDecl(context, type->declRef);
-
-
return LoweredTypeInfo(type);
}
@@ -3006,6 +3005,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return globalVal;
}
+ LoweredValInfo visitGenericValueParamDecl(GenericValueParamDecl* decl)
+ {
+ return LoweredValInfo::simple(context->irBuilder->getDeclRefVal(DeclRefBase(decl)));
+ }
+
LoweredValInfo visitVarDeclBase(VarDeclBase* decl)
{
// Detect global (or effectively global) variables
@@ -3733,7 +3737,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
if (auto innerFuncDecl = genDecl->inner->As<FuncDecl>())
return lowerFuncDecl(innerFuncDecl);
else if (auto innerStructDecl = genDecl->inner->As<StructDecl>())
+ {
+ visitAggTypeDecl(innerStructDecl);
return LoweredValInfo();
+ }
SLANG_RELEASE_ASSERT(false);
UNREACHABLE_RETURN(LoweredValInfo());
}
@@ -3910,6 +3917,32 @@ RefPtr<GenericSubstitution> lowerGenericSubstitutions(
return result;
}
+RefPtr<GlobalGenericParamSubstitution> lowerGlobalGenericSubstitutions(
+ IRGenContext* context,
+ GlobalGenericParamSubstitution* genSubst)
+{
+ if (!genSubst)
+ return nullptr;
+ RefPtr<GlobalGenericParamSubstitution> result;
+ RefPtr<GlobalGenericParamSubstitution> newSubst = new GlobalGenericParamSubstitution();
+ newSubst->actualType = lowerSubstitutionArg(context, genSubst->actualType);
+ newSubst->paramDecl = genSubst->paramDecl;
+ for (auto & tbl : genSubst->witnessTables)
+ {
+ auto ntbl = tbl;
+ ntbl.Value = lowerSubstitutionArg(context, tbl.Value);
+ newSubst->witnessTables.Add(ntbl);
+ }
+ result = newSubst;
+ if (genSubst->outer)
+ {
+ result->outer = lowerGlobalGenericSubstitutions(
+ context,
+ genSubst->outer);
+ }
+ return result;
+}
+
RefPtr<ThisTypeSubstitution> lowerThisTypeSubstitution(
IRGenContext* context,
ThisTypeSubstitution* thisSubst)
@@ -3926,7 +3959,7 @@ SubstitutionSet lowerSubstitutions(IRGenContext* context, SubstitutionSet subst)
SubstitutionSet rs;
rs.genericSubstitutions = lowerGenericSubstitutions(context, subst.genericSubstitutions);
rs.thisTypeSubstitution = lowerThisTypeSubstitution(context, subst.thisTypeSubstitution);
- rs.globalGenParamSubstitutions = subst.globalGenParamSubstitutions;
+ rs.globalGenParamSubstitutions = lowerGlobalGenericSubstitutions(context, subst.globalGenParamSubstitutions);
return rs;
}
@@ -3973,10 +4006,10 @@ LoweredValInfo maybeEmitSpecializeInst(IRGenContext* context,
// need to walk through those and replace things in
// cases where the `Val`s used for substitution should
// lower to something other than their original form.
- auto lowedNewSubst = lowerGenericSubstitutions(context, newSubst);
- DeclRef<Decl> newDeclRef = DeclRef<Decl>(declRef.decl,
- SubstitutionSet(lowedNewSubst, declRef.substitutions.thisTypeSubstitution,
- declRef.substitutions.globalGenParamSubstitutions));
+ SubstitutionSet oldSubst = declRef.substitutions;
+ oldSubst.genericSubstitutions = newSubst;
+ auto lowedNewSubst = lowerSubstitutions(context, oldSubst);
+ DeclRef<Decl> newDeclRef = DeclRef<Decl>(declRef.decl, lowedNewSubst);
RefPtr<Type> type;
if (auto declType = val->getType())
@@ -4014,9 +4047,9 @@ static void lowerEntryPointToIR(
return;
}
// we need to lower all global type arguments as well
+ auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl);
for (auto arg : entryPointRequest->genericParameterTypes)
lowerType(context, arg);
- auto loweredEntryPointFunc = ensureDecl(context, entryPointFuncDecl);
}
#if 0
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index e5ea1d531..e1c5c1aca 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -2109,14 +2109,15 @@ RefPtr<ProgramLayout> specializeProgramLayout(
auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules();
structLayout->rules = constantBufferRules;
-
+ structLayout->fields.SetSize(globalStructLayout->fields.Count());
UniformLayoutInfo structLayoutInfo;
structLayoutInfo.alignment = globalStructLayout->uniformAlignment;
structLayoutInfo.size = 0;
bool anyUniforms = false;
Dictionary<RefPtr<VarLayout>, RefPtr<VarLayout>> varLayoutMapping;
- for (auto & varLayout : globalStructLayout->fields)
+ for (uint32_t varId = 0; varId < globalStructLayout->fields.Count(); varId++)
{
+ auto &varLayout = globalStructLayout->fields[varId];
// To recover layout context, we skip generic resources in the first pass
if (varLayout->FindResourceInfo(LayoutResourceKind::GenericResource))
continue;
@@ -2141,7 +2142,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
resInfo.index,
resInfo.index + tresInfo.count);
}
- structLayout->fields.Add(varLayout);
+ structLayout->fields[varId] = varLayout;
varLayoutMapping[varLayout] = varLayout;
}
auto originalGlobalCBufferInfo = programLayout->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer);
@@ -2156,8 +2157,9 @@ RefPtr<ProgramLayout> specializeProgramLayout(
globalCBufferInfo.index = originalGlobalCBufferInfo->index;
}
// we have the context restored, can continue to layout the generic variables now
- for (auto & varLayout : globalStructLayout->fields)
+ for (uint32_t varId = 0; varId < globalStructLayout->fields.Count(); varId++)
{
+ auto &varLayout = globalStructLayout->fields[varId];
if (varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::GenericResource))
{
RefPtr<Type> newType = varLayout->typeLayout->type->Substitute(typeSubst).As<Type>();
@@ -2202,7 +2204,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
newVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset;
anyUniforms = true;
}
- structLayout->fields.Add(newVarLayout);
+ structLayout->fields[varId] = newVarLayout;
varLayoutMapping[varLayout] = newVarLayout;
}
}
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index 7e36b0e71..531606f8d 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -2704,6 +2704,7 @@ namespace Slang
PushScope(program);
program->loc = tokenReader.PeekLoc();
+ program->scope = currentScope;
ParseDeclBody(this, program, TokenType::EndOfFile);
PopScope();
@@ -3960,6 +3961,17 @@ namespace Slang
return parsePrefixExpr(this);
}
+ RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit,
+ TokenSpan const& tokens,
+ DiagnosticSink* sink,
+ RefPtr<Scope> const& outerScope)
+ {
+ Parser parser(tokens, sink, outerScope);
+ parser.translationUnit = translationUnit;
+ parser.currentScope = outerScope;
+ return parser.ParseType();
+ }
+
// Parse a source file into an existing translation unit
void parseSourceFile(
TranslationUnitRequest* translationUnit,
@@ -3971,6 +3983,7 @@ namespace Slang
parser.translationUnit = translationUnit;
+
return parser.parseSourceFile(translationUnit->SyntaxNode.Ptr());
}
diff --git a/source/slang/parser.h b/source/slang/parser.h
index 60fe4b3ae..785b6e345 100644
--- a/source/slang/parser.h
+++ b/source/slang/parser.h
@@ -14,6 +14,11 @@ namespace Slang
DiagnosticSink* sink,
RefPtr<Scope> const& outerScope);
+ RefPtr<Expr> parseTypeFromSourceFile(TranslationUnitRequest* translationUnit,
+ TokenSpan const& tokens,
+ DiagnosticSink* sink,
+ RefPtr<Scope> const& outerScope);
+
RefPtr<ModuleDecl> populateBaseLanguageModule(
Session* session,
RefPtr<Scope> scope);
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index c9de75d6e..b0be58274 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -433,20 +433,8 @@ SLANG_API SlangReflectionType * spReflection_FindTypeByName(SlangReflection * re
auto context = convert(reflection);
auto compileRequest = context->targetRequest->compileRequest;
- RefPtr<Type> result;
- if (compileRequest->types.TryGetValue(name, result))
- return (SlangReflectionType*)result.Ptr();
-
- auto nameObj = compileRequest->getNamePool()->getName(name);
- Decl* resultDecl = compileRequest->lookupGlobalDecl(nameObj);
- if (resultDecl)
- {
- RefPtr<DeclRefType> declRefType = new DeclRefType();
- declRefType->declRef.decl = resultDecl;
- compileRequest->types[name] = declRefType;
- return (SlangReflectionType*)declRefType.Ptr();
- }
- return nullptr;
+ RefPtr<Type> result = compileRequest->getTypeFromString(name);
+ return (SlangReflectionType*)result.Ptr();
}
SLANG_API SlangReflectionTypeLayout* spReflection_GetTypeLayout(
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 2ebf024e3..4c9ecf8a8 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -125,6 +125,50 @@ CompileRequest::CompileRequest(Session* session)
CompileRequest::~CompileRequest()
{}
+
+RefPtr<Expr> CompileRequest::parseTypeString(TranslationUnitRequest * translationUnit, String typeStr, RefPtr<Scope> scope)
+{
+ Slang::SourceFile srcFile;
+ srcFile.content = typeStr;
+ DiagnosticSink sink;
+ sink.sourceManager = sourceManager;
+ auto tokens = preprocessSource(
+ &srcFile,
+ &sink,
+ nullptr,
+ Dictionary<String,String>(),
+ translationUnit);
+ return parseTypeFromSourceFile(translationUnit, tokens, &sink, scope);
+}
+
+RefPtr<Type> checkProperType(TranslationUnitRequest * tu, TypeExp typeExp);
+Type* CompileRequest::getTypeFromString(String typeStr)
+{
+ RefPtr<Type> type;
+ if (types.TryGetValue(typeStr, type))
+ return type;
+ auto translationUnit = translationUnits.First();
+ List<RefPtr<Scope>> scopesToTry;
+ for (auto tu : translationUnits)
+ scopesToTry.Add(tu->SyntaxNode->scope);
+ for (auto & module : loadedModulesList)
+ scopesToTry.Add(module->moduleDecl->scope);
+ // parse type name
+ for (auto & s : scopesToTry)
+ {
+ RefPtr<Expr> typeExpr = parseTypeString(translationUnit,
+ typeStr, s);
+ type = checkProperType(translationUnit, TypeExp(typeExpr));
+ if (type)
+ break;
+ }
+ if (type)
+ {
+ types[typeStr] = type;
+ }
+ return type.Ptr();
+}
+
void CompileRequest::parseTranslationUnit(
TranslationUnitRequest* translationUnit)
{
@@ -429,7 +473,7 @@ int CompileRequest::addEntryPoint(
entryPoint->profile = entryPointProfile;
entryPoint->translationUnitIndex = translationUnitIndex;
for (auto typeName : genericTypeNames)
- entryPoint->genericParameterTypeNames.Add(getNamePool()->getName(typeName));
+ entryPoint->genericParameterTypeNames.Add(typeName);
auto translationUnit = translationUnits[translationUnitIndex].Ptr();
translationUnit->entryPoints.Add(entryPoint);
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 552f1dc56..ab4a5f94c 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -355,19 +355,21 @@ void Type::accept(IValVisitor* visitor, void* extra)
auto arrType = type->AsArrayType();
if (!arrType)
return false;
- return (ArrayLength == arrType->ArrayLength && baseType->Equals(arrType->baseType.Ptr()));
+ return (ArrayLength->EqualsVal(arrType->ArrayLength) && baseType->Equals(arrType->baseType.Ptr()));
}
RefPtr<Val> ArrayExpressionType::SubstituteImpl(SubstitutionSet subst, int* ioDiff)
{
int diff = 0;
auto elementType = baseType->SubstituteImpl(subst, &diff).As<Type>();
+ auto arrlen = ArrayLength->SubstituteImpl(subst, &diff).As<IntVal>();
+ SLANG_ASSERT(arrlen);
if (diff)
{
*ioDiff = 1;
auto rsType = getArrayType(
elementType,
- ArrayLength);
+ arrlen);
return rsType;
}
return this;
diff --git a/tests/compute/array-param.slang b/tests/compute/array-param.slang
new file mode 100644
index 000000000..78ca52518
--- /dev/null
+++ b/tests/compute/array-param.slang
@@ -0,0 +1,19 @@
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
+
+RWStructuredBuffer<int> outputBuffer;
+void writeArray(inout float3 a[4])
+{
+ a[0] = float3(1, 1, 1);
+ a[1] = float3(1, 1, 1);
+ a[2] = float3(1, 1, 1);
+ a[3] = float3(1, 1, 1);
+}
+
+[numthreads(4, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ float3 b[4];
+ writeArray(b);
+ outputBuffer[dispatchThreadID.x] = b[0].x;
+} \ No newline at end of file
diff --git a/tests/compute/array-param.slang.expected.txt b/tests/compute/array-param.slang.expected.txt
new file mode 100644
index 000000000..ef529012e
--- /dev/null
+++ b/tests/compute/array-param.slang.expected.txt
@@ -0,0 +1,4 @@
+1
+1
+1
+1 \ No newline at end of file
diff --git a/tests/compute/global-type-param3.slang b/tests/compute/global-type-param-array.slang
index 05793dce4..74e52d5d4 100644
--- a/tests/compute/global-type-param3.slang
+++ b/tests/compute/global-type-param-array.slang
@@ -1,23 +1,10 @@
-//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
-//TEST_INPUT: cbuffer(data=[1.0], stride=4):dxbinding(0),glbinding(0)
+//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT: cbuffer(data=[1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0 1.0], stride=4):dxbinding(0),glbinding(0)
//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
-//TEST_INPUT: type Impl
+//TEST_INPUT: type Pair<Arr<Base,1>, Pair<Arr<Base,2> , Base> >
RWStructuredBuffer<float> outputBuffer;
-
-interface IBase
-{
- float compute();
-}
-
-struct Impl : IBase
-{
- float base; // = 1.0
- float compute()
- {
- return 1.0;
- }
-};
+import globalTypeParamArrayShared;
__generic_param TImpl : IBase;
@@ -25,7 +12,7 @@ ParameterBlock<TImpl> impl;
float doCompute<T:IBase>(T t)
{
- return t.compute();
+ return t.compute(1.0);
}
[numthreads(1, 1, 1)]
diff --git a/tests/compute/global-type-param-array.slang.expected.txt b/tests/compute/global-type-param-array.slang.expected.txt
new file mode 100644
index 000000000..bdf6b77dc
--- /dev/null
+++ b/tests/compute/global-type-param-array.slang.expected.txt
@@ -0,0 +1 @@
+40800000
diff --git a/tests/compute/global-type-param.slang b/tests/compute/global-type-param.slang
index 301ef1021..03f5df329 100644
--- a/tests/compute/global-type-param.slang
+++ b/tests/compute/global-type-param.slang
@@ -1,6 +1,6 @@
//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
//TEST_INPUT:ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
-//TEST_INPUT:type Impl
+//TEST_INPUT:type Wrapper<Impl>
RWStructuredBuffer<float> outputBuffer;
@@ -9,6 +9,15 @@ interface IBase
float compute();
}
+struct Wrapper<T : IBase> : IBase
+{
+ T obj;
+ float compute()
+ {
+ return obj.compute();
+ }
+};
+
struct Impl : IBase
{
float compute()
diff --git a/tests/compute/global-type-param3.slang.expected.txt b/tests/compute/global-type-param3.slang.expected.txt
deleted file mode 100644
index deb1c3630..000000000
--- a/tests/compute/global-type-param3.slang.expected.txt
+++ /dev/null
@@ -1 +0,0 @@
-3F800000
diff --git a/tests/compute/globalTypeParamArrayShared.slang b/tests/compute/globalTypeParamArrayShared.slang
new file mode 100644
index 000000000..ee3caa372
--- /dev/null
+++ b/tests/compute/globalTypeParamArrayShared.slang
@@ -0,0 +1,32 @@
+//TEST_IGNORE_FILE:
+interface IBase
+{
+ float compute<T>(T g);
+}
+struct Base:IBase
+{
+ float b;
+ float compute<T>(T g) { return b; }
+};
+
+struct Pair<T1:IBase, T2:IBase> : IBase
+{
+ T1 head;
+ T2 tail;
+ float compute<T>(T g)
+ {
+ return head.compute(g) + tail.compute(g);
+ }
+};
+
+struct Arr<T:IBase, let N:int> : IBase
+{
+ T base[N]; // = 1.0
+ float compute<T>(T g)
+ {
+ float sum = 0.0;
+ for (int i = 0; i < N; i++)
+ sum += base[i].compute(g);
+ return sum;
+ }
+};
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index 01328eabd..fcf25f376 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -20,7 +20,10 @@ namespace renderer_test
if (parser.LookAhead("type"))
{
parser.ReadToken();
- globalTypeArguments.Add(parser.ReadWord());
+ StringBuilder typeExp;
+ while (!parser.IsEnd())
+ typeExp << parser.ReadToken().Content;
+ globalTypeArguments.Add(typeExp);
}
else
{
diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp
index cfbc24382..9263aa41b 100644
--- a/tools/render-test/slang-support.cpp
+++ b/tools/render-test/slang-support.cpp
@@ -100,6 +100,7 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
(int)rawTypeNames.Count(),
rawTypeNames.Buffer());
int compileErr = spCompile(slangRequest);
+ spSetLineDirectiveMode(slangRequest, SLANG_LINE_DIRECTIVE_MODE_NONE);
if (auto diagnostics = spGetDiagnosticOutput(slangRequest))
{
fprintf(stderr, "%s", diagnostics);