summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--slang.h28
-rw-r--r--source/core/list.h8
-rw-r--r--source/slang/check.cpp102
-rw-r--r--source/slang/compiler.cpp6
-rw-r--r--source/slang/compiler.h12
-rw-r--r--source/slang/decl-defs.h5
-rw-r--r--source/slang/diagnostic-defs.h9
-rw-r--r--source/slang/emit.cpp11
-rw-r--r--source/slang/emit.h5
-rw-r--r--source/slang/ir-insts.h4
-rw-r--r--source/slang/ir.cpp119
-rw-r--r--source/slang/lookup.cpp7
-rw-r--r--source/slang/lower-to-ir.cpp4
-rw-r--r--source/slang/lower.cpp7
-rw-r--r--source/slang/parameter-binding.cpp239
-rw-r--r--source/slang/parser.cpp103
-rw-r--r--source/slang/reflection.cpp19
-rw-r--r--source/slang/slang.cpp32
-rw-r--r--source/slang/syntax-base-defs.h32
-rw-r--r--source/slang/syntax.cpp62
-rw-r--r--source/slang/syntax.h2
-rw-r--r--source/slang/type-layout.cpp32
-rw-r--r--source/slang/type-layout.h20
-rw-r--r--tests/compute/global-type-param.slang30
-rw-r--r--tests/compute/global-type-param.slang.expected.txt1
-rw-r--r--tests/compute/global-type-param1.slang46
-rw-r--r--tests/compute/global-type-param1.slang.expected.txt1
-rw-r--r--tests/compute/global-type-param2.slang61
-rw-r--r--tests/compute/global-type-param2.slang.expected.txt1
-rw-r--r--tools/render-test/main.cpp1
-rw-r--r--tools/render-test/render-d3d11.cpp20
-rw-r--r--tools/render-test/render.h1
-rw-r--r--tools/render-test/shader-input-layout.cpp357
-rw-r--r--tools/render-test/shader-input-layout.h1
-rw-r--r--tools/render-test/slang-support.cpp9
35 files changed, 1123 insertions, 274 deletions
diff --git a/slang.h b/slang.h
index c4c2f54b7..30808a242 100644
--- a/slang.h
+++ b/slang.h
@@ -377,6 +377,18 @@ extern "C"
char const* name,
SlangProfileID profile);
+ /** Add an entry point in a particular translation unit,
+ with additional arguments that specify the concrete
+ type names for global generic type parameters.
+ */
+ SLANG_API int spAddEntryPointEx(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ char const* name,
+ SlangProfileID profile,
+ int genericTypeNameCount,
+ char const** genericTypeNames);
+
/** Execute the compilation request.
Returns zero on success, non-zero on failure.
@@ -588,6 +600,9 @@ extern "C"
// HLSL register `space`, Vulkan GLSL `set`
SLANG_PARAMETER_CATEGORY_REGISTER_SPACE,
+ // A parameter whose type is to be specialized by a global generic type argument
+ SLANG_PARAMETER_CATEGORY_GENERIC,
+
//
SLANG_PARAMETER_CATEGORY_COUNT,
};
@@ -695,6 +710,8 @@ extern "C"
SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* reflection);
SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* reflection, SlangUInt index);
+ SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection);
+ SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection);
#ifdef __cplusplus
}
@@ -848,6 +865,7 @@ namespace slang
SpecializationConstant = SLANG_PARAMETER_CATEGORY_SPECIALIZATION_CONSTANT,
PushConstantBuffer = SLANG_PARAMETER_CATEGORY_PUSH_CONSTANT_BUFFER,
RegisterSpace = SLANG_PARAMETER_CATEGORY_REGISTER_SPACE,
+ GenericResource = SLANG_PARAMETER_CATEGORY_GENERIC,
};
struct TypeLayoutReflection
@@ -1102,6 +1120,16 @@ namespace slang
{
return (EntryPointReflection*) spReflection_getEntryPointByIndex((SlangReflection*) this, index);
}
+
+ SlangUInt getGlobalConstantBufferBinding()
+ {
+ return spReflection_getGlobalConstantBufferBinding((SlangReflection*)this);
+ }
+
+ size_t getGlobalConstantBufferSize()
+ {
+ return spReflection_getGlobalConstantBufferSize((SlangReflection*)this);
+ }
};
}
diff --git a/source/core/list.h b/source/core/list.h
index af32a39ef..b1461a260 100644
--- a/source/core/list.h
+++ b/source/core/list.h
@@ -487,7 +487,7 @@ namespace Slang
if (predicate(buffer[i]))
return i;
}
- return -1;
+ return (UInt)-1;
}
template<typename Func>
@@ -498,7 +498,7 @@ namespace Slang
if (predicate(buffer[i-1]))
return i-1;
}
- return -1;
+ return (UInt)-1;
}
template<typename T2>
@@ -509,7 +509,7 @@ namespace Slang
if (buffer[i] == val)
return i;
}
- return -1;
+ return (UInt)-1;
}
template<typename T2>
@@ -520,7 +520,7 @@ namespace Slang
if(buffer[i-1] == val)
return i-1;
}
- return -1;
+ return (UInt)-1;
}
void Sort()
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 233a82eef..4b8f4f4c1 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -148,7 +148,6 @@ namespace Slang
return expr->type->As<DeclRefType>();
}
-
RefPtr<Expr> ConstructDeclRefExpr(
DeclRef<Decl> declRef,
RefPtr<Expr> baseExpr,
@@ -1998,6 +1997,22 @@ namespace Slang
decl->SetCheckState(DeclCheckState::Checked);
}
+ void visitGlobalGenericParamDecl(GlobalGenericParamDecl * decl)
+ {
+ if (decl->IsChecked(DeclCheckState::Checked)) return;
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+ // global generic param only allowed in global scope
+ auto program = decl->ParentDecl->As<ModuleDecl>();
+ if (!program)
+ getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly);
+ // Now check all of the member declarations.
+ for (auto member : decl->Members)
+ {
+ checkDecl(member);
+ }
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
void visitAssocTypeDecl(AssocTypeDecl* decl)
{
if (decl->IsChecked(DeclCheckState::Checked)) return;
@@ -3703,6 +3718,19 @@ namespace Slang
return true;
}
}
+ // if an inheritance decl is not found, try to find a GenericTypeConstraintDecl
+ for (auto genConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(aggTypeDeclRef))
+ {
+ EnsureDecl(genConstraintDeclRef.getDecl());
+ auto inheritedType = GetSup(genConstraintDeclRef);
+ TypeWitnessBreadcrumb breadcrumb;
+ breadcrumb.prev = inBreadcrumbs;
+ breadcrumb.declRef = genConstraintDeclRef;
+ if (doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb))
+ {
+ return true;
+ }
+ }
}
else if( auto genericTypeParamDeclRef = declRef.As<GenericTypeParamDecl>() )
{
@@ -6582,6 +6610,78 @@ namespace Slang
// that we don't have to re-do this effort again later.
entryPoint->decl = entryPointFuncDecl;
+ // Lookup generic parameter types in global scope
+ for (auto name : entryPoint->genericParameterTypeNames)
+ {
+ if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName))
+ {
+ // If there doesn't appear to be any such declaration, then
+ // we need to diagnose it as an error, and then bail out.
+ sink->diagnose(translationUnitSyntax, Diagnostics::entryPointTypeParameterNotFound, name);
+ return;
+ }
+ RefPtr<Type> type;
+ if (auto aggType = firstDeclWithName->As<AggTypeDecl>())
+ {
+ type = DeclRefType::Create(entryPoint->compileRequest->mSession, DeclRef<Decl>(aggType, nullptr));
+ }
+ else if (auto typeDefDecl = firstDeclWithName->As<TypeDefDecl>())
+ {
+ type = GetType(DeclRef<TypeDefDecl>(typeDefDecl, nullptr));
+ }
+ else
+ {
+ sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name);
+ return;
+ }
+ entryPoint->genericParameterTypes.Add(type);
+ }
+ // check that user-provioded type arguments conforms to the generic type
+ // parameter declaration of this translation unit
+
+ // collect global generic parameters from all imported modules
+ List<RefPtr<GlobalGenericParamDecl>> globalGenericParams;
+ // add current translation unit first
+ {
+ auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>();
+ for (auto p : globalGenParams)
+ globalGenericParams.Add(p);
+ }
+ // add imported modules
+ for (auto moduleDecl : entryPoint->compileRequest->loadedModulesList)
+ {
+ auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>();
+ for (auto p : globalGenParams)
+ globalGenericParams.Add(p);
+ }
+ if (globalGenericParams.Count() != entryPoint->genericParameterTypes.Count())
+ {
+ sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, globalGenericParams.Count(),
+ entryPoint->genericParameterTypes.Count());
+ return;
+ }
+ // if number of entry-point type arguments matches parameters, try find
+ // SubtypeWitness for each argument
+ int index = 0;
+ for (auto & gParam : globalGenericParams)
+ {
+ for (auto constraint : gParam->getMembersOfType<GenericTypeConstraintDecl>())
+ {
+ auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr));
+ SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit);
+ auto witness = visitor.tryGetSubtypeWitness(entryPoint->genericParameterTypes[index], interfaceType);
+ if (!witness)
+ {
+ sink->diagnose(gParam,
+ Diagnostics::typeArgumentDoesNotConformToInterface, gParam->nameAndLoc.name, entryPoint->genericParameterTypes[index],
+ interfaceType);
+ }
+ entryPoint->genericParameterWitnesses.Add(witness);
+ }
+ index++;
+ }
+ if (sink->errorCount != 0)
+ return;
// TODO: after all that work, we are now in a position to start
// validating the declaration itself. E.g., we should check if
// the declared input/output parameters have suitable semantics,
diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp
index 302b5704f..acbf51e2e 100644
--- a/source/slang/compiler.cpp
+++ b/source/slang/compiler.cpp
@@ -11,7 +11,7 @@
#include "parser.h"
#include "preprocessor.h"
#include "syntax-visitors.h"
-
+#include "type-layout.h"
#include "reflection.h"
#include "emit.h"
@@ -160,7 +160,7 @@ namespace Slang
entryPoint,
targetReq->layout.Ptr(),
CodeGenTarget::HLSL,
- targetReq->target);
+ targetReq);
}
}
@@ -207,7 +207,7 @@ namespace Slang
entryPoint,
targetReq->layout.Ptr(),
CodeGenTarget::GLSL,
- targetReq->target);
+ targetReq);
}
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index f42f36c1f..303be6624 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -100,6 +100,10 @@ namespace Slang
// The name of the entry point function (e.g., `main`)
Name* name;
+
+ // The type names we want to substitute into the
+ // global generic type parameters
+ List<Name*> genericParameterTypeNames;
// The profile that the entry point will be compiled for
// (this is a combination of the target state, and also
@@ -123,6 +127,11 @@ namespace Slang
// it should not be assumed to be available in cases
// where any errors were diagnosed.
RefPtr<FuncDecl> decl;
+
+ // The declaration of the global generic parameter types
+ // This will be filled in as part of semantic analysis.
+ List<RefPtr<Type>> genericParameterTypes;
+ List<RefPtr<Val>> genericParameterWitnesses;
};
enum class PassThroughMode : SlangPassThrough
@@ -319,7 +328,8 @@ namespace Slang
int addEntryPoint(
int translationUnitIndex,
String const& name,
- Profile profile);
+ Profile profile,
+ List<String> const & genericTypeNames);
UInt addTarget(
CodeGenTarget target);
diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h
index 9c010d156..e24a535c5 100644
--- a/source/slang/decl-defs.h
+++ b/source/slang/decl-defs.h
@@ -126,6 +126,11 @@ END_SYNTAX_CLASS()
SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl)
END_SYNTAX_CLASS()
+// A '__generic_param' declaration, which defines a generic
+// entry-point parameter. Is a container of GenericTypeConstraintDecl
+SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl)
+END_SYNTAX_CLASS()
+
// A scope for local declarations (e.g., as part of a statement)
SIMPLE_SYNTAX_CLASS(ScopeDecl, ContainerDecl)
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index 7f27e43e8..24e8bc713 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -196,7 +196,7 @@ DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of
// 303xx: interfaces and associated types
DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.")
-
+DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'__generic_param' can only be defined global scope.")
// TODO: need to assign numbers to all these extra diagnostics...
DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')")
@@ -244,11 +244,17 @@ DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches en
DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'")
DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function")
+DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'")
+DIAGNOSTIC(38005, Error, entryPointTypeSymbolNotAType, "entry-point type parameter '$0' must be declared as a type")
+
DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'")
DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type")
DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration")
DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration")
+DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.")
+DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.")
+
//
// 4xxxx - IL code generation.
//
@@ -264,7 +270,6 @@ DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value seman
//
// 5xxxx - Target code generation.
//
-
DIAGNOSTIC(50020, Error, unknownStageType, "Unknown stage type '$0'.")
DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.")
DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.")
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 5b7a42ad7..614e8f474 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -3481,9 +3481,9 @@ struct EmitVisitor
break;
case LayoutResourceKind::RegisterSpace:
+ case LayoutResourceKind::GenericResource:
// ignore
break;
-
default:
{
Emit(": register(");
@@ -6771,7 +6771,7 @@ EntryPointLayout* findEntryPointLayout(
StructTypeLayout* getGlobalStructLayout(
ProgramLayout* programLayout)
{
- auto globalScopeLayout = programLayout->globalScopeLayout;
+ auto globalScopeLayout = programLayout->globalScopeLayout->typeLayout;
if( auto gs = globalScopeLayout.As<StructTypeLayout>() )
{
return gs.Ptr();
@@ -6816,13 +6816,13 @@ String emitEntryPoint(
EntryPointRequest* entryPoint,
ProgramLayout* programLayout,
CodeGenTarget target,
- CodeGenTarget finalTarget)
+ TargetRequest* targetRequest)
{
auto translationUnit = entryPoint->getTranslationUnit();
SharedEmitContext sharedContext;
sharedContext.target = target;
- sharedContext.finalTarget = finalTarget;
+ sharedContext.finalTarget = targetRequest->target;
sharedContext.entryPoint = entryPoint;
if (entryPoint)
@@ -6890,7 +6890,8 @@ String emitEntryPoint(
auto lowered = specializeIRForEntryPoint(
entryPoint,
programLayout,
- target);
+ target,
+ targetRequest);
// If the user specified the flag that they want us to dump
// IR, then do it here, for the target-specific, but
diff --git a/source/slang/emit.h b/source/slang/emit.h
index e17a84d5a..98845f9c6 100644
--- a/source/slang/emit.h
+++ b/source/slang/emit.h
@@ -26,8 +26,7 @@ namespace Slang
// The target language to generate code in (e.g., HLSL/GLSL)
CodeGenTarget target,
- // The "final" target that we are being asked to compile for
- // (e.g., SPIR-V, DXBC, ...).
- CodeGenTarget finalTarget);
+ // The full target request
+ TargetRequest* targetRequest);
}
#endif
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 52acf6576..a91143a43 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -419,6 +419,7 @@ struct IRBuilder
IRFunc* createFunc();
IRGlobalVar* createGlobalVar(
IRType* valueType);
+ IRWitnessTable* createWitnessTable(Dictionary<DeclRef<Decl>, Decl*> & witnesses);
IRWitnessTable* createWitnessTable();
IRWitnessTableEntry* createWitnessTableEntry(
IRWitnessTable* witnessTable,
@@ -565,7 +566,8 @@ struct IRBuilder
IRModule* specializeIRForEntryPoint(
EntryPointRequest* entryPointRequest,
ProgramLayout* programLayout,
- CodeGenTarget target);
+ CodeGenTarget target,
+ TargetRequest* targetReq);
// Find suitable uses of the `specialize` instruction that
// can be replaced with references to specialized functions.
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 9068e717b..bfc26643c 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -3089,12 +3089,16 @@ namespace Slang
// to the layout to use for it.
Dictionary<String, VarLayout*> globalVarLayouts;
+ RefPtr<GlobalGenericParamSubstitution> subst;
+
// Override the "maybe clone" logic so that we always clone
virtual IRValue* maybeCloneValue(IRValue* originalVal) override;
// Override teh "maybe clone" logic so that we carefully
// clone any IR proxy values inside substitutions
virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) override;
+
+ virtual RefPtr<Type> maybeCloneType(Type* originalType) override;
};
@@ -3102,6 +3106,11 @@ namespace Slang
IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc);
IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar);
+ RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType)
+ {
+ return originalType->Substitute(subst).As<Type>();
+ }
+
IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue)
{
switch (originalValue->op)
@@ -3143,6 +3152,33 @@ namespace Slang
case kIROp_decl_ref:
{
IRDeclRef* od = (IRDeclRef*)originalValue;
+
+ // if the declRef is one of the __generic_param decl being substituted by subst
+ // return the substituted decl
+ if (subst)
+ {
+ if (od->declRef.getDecl() == subst->paramDecl)
+ return builder->getTypeVal(subst->actualType.As<Type>());
+ else if (auto genConstraint = od->declRef.As<GenericTypeConstraintDecl>())
+ {
+ // a decl-ref to GenericTypeConstraintDecl as a result of
+ // referencing a generic parameter type should be replaced with
+ // the actual witness table
+ if (genConstraint.getDecl()->ParentDecl == subst->paramDecl)
+ {
+ // find the witness table from subst
+ for (auto witness : subst->witnessTables)
+ {
+ if (witness.Key->EqualsVal(GetSup(genConstraint)))
+ {
+ auto proxyVal = witness.Value.As<IRProxyVal>();
+ SLANG_ASSERT(proxyVal);
+ return proxyVal->inst;
+ }
+ }
+ }
+ }
+ }
auto declRef = maybeCloneDeclRef(od->declRef);
return builder->getDeclRefVal(declRef);
}
@@ -3150,7 +3186,9 @@ namespace Slang
case kIROp_TypeType:
{
IRValue* od = (IRValue*)originalValue;
- return builder->getTypeVal(od->type);
+ int ioDiff = 0;
+ auto newType = od->type->SubstituteImpl(subst, &ioDiff);
+ return builder->getTypeVal(newType.As<Type>());
}
break;
default:
@@ -3207,7 +3245,9 @@ namespace Slang
newSubst->outer = cloneSubstitutions(context, subst->outer);
return newSubst;
}
- return nullptr;
+ else
+ SLANG_UNREACHABLE("unimplemented cloneSubstitution");
+ UNREACHABLE_RETURN(nullptr);
}
DeclRef<Decl> IRSpecContext::maybeCloneDeclRef(DeclRef<Decl> const& declRef)
@@ -3281,7 +3321,7 @@ namespace Slang
IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar)
{
- auto clonedVar = context->builder->createGlobalVar(originalVar->getType()->getValueType());
+ auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType()));
registerClonedValue(context, clonedVar, originalVar);
auto mangledName = originalVar->mangledName;
@@ -3703,10 +3743,67 @@ namespace Slang
}
}
+ // implementation provided in parameter-binding.cpp
+ RefPtr<ProgramLayout> specializeProgramLayout(
+ TargetRequest * targetReq,
+ ProgramLayout* programLayout,
+ Substitutions * typeSubst);
+
+ RefPtr<GlobalGenericParamSubstitution> createGlobalGenericParamSubstitution(
+ EntryPointRequest * entryPointRequest,
+ ProgramLayout * programLayout,
+ IRSpecContext* context,
+ IRModule* originalIRModule)
+ {
+ RefPtr<GlobalGenericParamSubstitution> globalParamSubst;
+ Substitutions * curTailSubst = nullptr;
+ for (auto param : programLayout->globalGenericParams)
+ {
+ auto paramSubst = new GlobalGenericParamSubstitution();
+ if (!globalParamSubst)
+ globalParamSubst = paramSubst;
+ if (curTailSubst)
+ curTailSubst->outer = paramSubst;
+ curTailSubst = paramSubst;
+ paramSubst->paramDecl = param->decl;
+ SLANG_ASSERT((UInt)param->index < entryPointRequest->genericParameterTypes.Count());
+ paramSubst->actualType = entryPointRequest->genericParameterTypes[param->index];
+ // find witness tables
+ for (auto witness : entryPointRequest->genericParameterWitnesses)
+ {
+ if (auto subtypeWitness = witness.As<SubtypeWitness>())
+ {
+ if (subtypeWitness->sub->EqualsVal(paramSubst->actualType))
+ {
+ auto witnessTableName = getMangledNameForConformanceWitness(subtypeWitness->sub, subtypeWitness->sup);
+ auto globalVar = originalIRModule->getFirstGlobalValue();
+ IRGlobalValue * table = nullptr;
+ while (globalVar)
+ {
+ if (globalVar->mangledName == witnessTableName)
+ {
+ table = globalVar;
+ break;
+ }
+ globalVar = globalVar->getNextValue();
+ }
+ SLANG_ASSERT(table);
+ table = cloneWitnessTable(context, (IRWitnessTable*)(table));
+ IRProxyVal * tableVal = new IRProxyVal();
+ tableVal->inst = table;
+ paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal));
+ }
+ }
+ }
+ }
+ return globalParamSubst;
+ }
+
IRModule* specializeIRForEntryPoint(
EntryPointRequest* entryPointRequest,
ProgramLayout* programLayout,
- CodeGenTarget target)
+ CodeGenTarget target,
+ TargetRequest* targetReq)
{
auto compileRequest = entryPointRequest->compileRequest;
auto session = compileRequest->mSession;
@@ -3720,8 +3817,6 @@ namespace Slang
return nullptr;
}
- auto entryPointLayout = findEntryPointLayout(programLayout, entryPointRequest);
-
// We now need to start cloning IR symbols from `originalIRModule`
// into a fresh IR module for this entry point. Along the way we need to:
//
@@ -3746,11 +3841,21 @@ namespace Slang
context->builder = &sharedContextStorage.builderStorage;
context->target = target;
+ // Create the GlobalGenericParamSubstitution for substituting global generic types
+ // into user-provided type arguments
+ auto globalParamSubst = createGlobalGenericParamSubstitution(entryPointRequest, programLayout, context, originalIRModule);
+
+ context->subst = globalParamSubst;
+
+ // now specailize the program layout using the substitution
+ RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(targetReq, programLayout, globalParamSubst);
+
+ auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest);
// Next, we want to optimize lookup for layout infromation
// associated with global declarations, so that we can
// look things up based on the IR values (using mangled names)
- auto globalStructLayout = getGlobalStructLayout(programLayout);
+ auto globalStructLayout = getGlobalStructLayout(newProgramLayout);
for (auto globalVarLayout : globalStructLayout->fields)
{
String mangledName = getMangledName(globalVarLayout->varDecl);
diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp
index b01732362..86bef3f4d 100644
--- a/source/slang/lookup.cpp
+++ b/source/slang/lookup.cpp
@@ -410,9 +410,9 @@ void lookUpMemberImpl(
if (auto declRefType = type->As<DeclRefType>())
{
auto declRef = declRefType->declRef;
- if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
+ if (declRef.As<AssocTypeDecl>() || declRef.As<GlobalGenericParamDecl>())
{
- for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef))
+ for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(declRef.As<ContainerDecl>()))
{
// The super-type in the constraint (e.g., `Foo` in `T : Foo`)
// will tell us a type we should use for lookup.
@@ -488,5 +488,4 @@ LookupResult lookUpMember(
return result;
}
-
-}
+} \ No newline at end of file
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 326d25649..0f3e85805 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -3538,7 +3538,9 @@ static void lowerEntryPointToIR(
// the entry point request.
return;
}
-
+ // we need to lower all global type arguments as well
+ for (auto arg : entryPointRequest->genericParameterTypes)
+ lowerType(context, arg);
auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl);
}
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index b375fa80e..5a6603add 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -2870,6 +2870,13 @@ struct LoweringVisitor
UNREACHABLE_RETURN(LoweredDecl());
}
+ LoweredDecl visitGlobalGenericParamDecl(GlobalGenericParamDecl * /*decl*/)
+ {
+ // not supported
+ SLANG_UNREACHABLE("visitGlobalGenericParamDecl in LowerVisitor");
+ UNREACHABLE_RETURN(LoweredDecl());
+ }
+
LoweredDecl visitTypeDefDecl(TypeDefDecl* decl)
{
if (shared->target == CodeGenTarget::GLSL)
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index fa015186b..836ed254f 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -667,6 +667,17 @@ static void collectGlobalScopeGLSLVaryingParameter(
}
// Collect a single declaration into our set of parameters
+static void collectGlobalGenericParameter(
+ ParameterBindingContext* context,
+ RefPtr<GlobalGenericParamDecl> paramDecl)
+{
+ RefPtr<GenericParamLayout> layout = new GenericParamLayout();
+ layout->decl = paramDecl;
+ layout->index = (int)context->shared->programLayout->globalGenericParams.Count();
+ context->shared->programLayout->globalGenericParams.Add(layout);
+}
+
+// Collect a single declaration into our set of parameters
static void collectGlobalScopeParameter(
ParameterBindingContext* context,
RefPtr<VarDeclBase> varDecl)
@@ -1037,7 +1048,13 @@ static void completeBindingsForParameter(
continue;
}
-
+ else if (kind == LayoutResourceKind::GenericResource)
+ {
+ bindingInfo.space = 0;
+ bindingInfo.count = 0;
+ bindingInfo.index = 0;
+ continue;
+ }
// For now we only auto-generate bindings in space zero
//
@@ -1065,6 +1082,11 @@ static void completeBindingsForParameter(
bindingInfo.space = space;
}
+ if (firstTypeLayout->FindResourceInfo(LayoutResourceKind::GenericResource))
+ {
+
+ }
+
// At this point we should have explicit binding locations chosen for
// all the relevant resource kinds, so we can apply these to the
// declarations:
@@ -1093,15 +1115,22 @@ static void collectGlobalScopeParameters(
ModuleDecl* program)
{
// First enumerate parameters at global scope
- for( auto decl : program->Members )
+ // We collect two things here:
+ // 1. A shader parameter, which is always a variable
+ // 2. A global entry-point generic parameter type (`__generic_param`),
+ // which is a GlobalGenericParamDecl
+ // We collect global generic type parameters in the first pass,
+ // So we can fill in the correct index into ordinary type layouts
+ // for generic types in the second pass.
+ for (auto decl : program->Members)
{
- // A shader parameter is always a variable,
- // so skip declarations that aren't variables.
- auto varDecl = decl.As<VarDeclBase>();
- if (!varDecl)
- continue;
-
- collectGlobalScopeParameter(context, varDecl);
+ if (auto genParamDecl = decl.As<GlobalGenericParamDecl>())
+ collectGlobalGenericParameter(context, genParamDecl);
+ }
+ for (auto decl : program->Members)
+ {
+ if (auto varDecl = decl.As<VarDeclBase>())
+ collectGlobalScopeParameter(context, varDecl);
}
// Next, we need to enumerate the parameters of
@@ -1665,7 +1694,8 @@ void generateParameterBindings(
if (!layoutContext.rules)
return;
- RefPtr<ProgramLayout> programLayout = new ProgramLayout;
+ RefPtr<ProgramLayout> programLayout = new ProgramLayout();
+ targetReq->layout = programLayout;
// Create a context to hold shared state during the process
// of generating parameter bindings
@@ -1680,7 +1710,6 @@ void generateParameterBindings(
context.shared = &sharedContext;
context.translationUnit = nullptr;
context.layoutContext = layoutContext;
-
// Walk through AST to discover all the parameters
collectParameters(&context, compileReq);
@@ -1707,6 +1736,7 @@ void generateParameterBindings(
// If there are any global-scope uniforms, then we need to
// allocate a constant-buffer binding for them here.
ParameterBindingInfo globalConstantBufferBinding;
+ globalConstantBufferBinding.index = 0;
if( anyGlobalUniforms )
{
// TODO: this logic is only correct for D3D targets, where
@@ -1838,8 +1868,191 @@ void generateParameterBindings(
// We now have a bunch of layout information, which we should
// record into a suitable object that represents the program
- programLayout->globalScopeLayout = globalScopeLayout;
- targetReq->layout = programLayout;
+ RefPtr<VarLayout> globalVarLayout = new VarLayout();
+ globalVarLayout->typeLayout = globalScopeLayout;
+ if (anyGlobalUniforms)
+ {
+ auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer);
+ cbInfo->space = 0;
+ cbInfo->index = globalConstantBufferBinding.index;
+ }
+ programLayout->globalScopeLayout = globalVarLayout;
}
+StructTypeLayout* getGlobalStructLayout(
+ ProgramLayout* programLayout);
+
+RefPtr<ProgramLayout> specializeProgramLayout(
+ TargetRequest * targetReq,
+ ProgramLayout* programLayout,
+ Substitutions * typeSubst)
+{
+ RefPtr<ProgramLayout> newProgramLayout;
+ newProgramLayout = new ProgramLayout();
+ newProgramLayout->bindingForHackSampler = programLayout->bindingForHackSampler;
+ newProgramLayout->hackSamplerVar = programLayout->hackSamplerVar;
+ for (auto & entryPoint : programLayout->entryPoints)
+ {
+ RefPtr<EntryPointLayout> newEntryPoint = new EntryPointLayout(*entryPoint);
+ // TODO: for now just copy existing entry point layouts, but we eventually need to
+ // specialize these as well...
+ newProgramLayout->entryPoints.Add(newEntryPoint);
+ }
+
+ List<RefPtr<TypeLayout>> paramTypeLayouts;
+ auto globalStructLayout = getGlobalStructLayout(programLayout);
+ SLANG_ASSERT(globalStructLayout);
+ RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
+ RefPtr<TypeLayout> globalScopeLayout = structLayout;
+ structLayout->uniformAlignment = globalStructLayout->uniformAlignment;
+
+ // Try to find rules based on the selected code-generation target
+ auto layoutContext = getInitialLayoutContextForTarget(targetReq);
+
+ // If there was no target, or there are no rules for the target,
+ // then bail out here.
+ if (!layoutContext.rules)
+ return newProgramLayout;
+
+
+ // we need to initialize a layout context to mark used registers
+ SharedParameterBindingContext sharedContext;
+ sharedContext.compileRequest = targetReq->compileRequest;
+ sharedContext.defaultLayoutRules = layoutContext.getRulesFamily();
+ sharedContext.programLayout = programLayout;
+
+ // Create a sub-context to collect parameters that get
+ // declared into the global scope
+ ParameterBindingContext context;
+ context.shared = &sharedContext;
+ context.translationUnit = nullptr;
+ context.layoutContext = layoutContext;
+
+ auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules();
+ structLayout->rules = constantBufferRules;
+
+ UniformLayoutInfo structLayoutInfo;
+ structLayoutInfo.alignment = globalStructLayout->uniformAlignment;
+ structLayoutInfo.size = 0;
+ bool anyUniforms = false;
+ Dictionary<RefPtr<VarLayout>, RefPtr<VarLayout>> varLayoutMapping;
+ for (auto & varLayout : globalStructLayout->fields)
+ {
+ // To recover layout context, we skip generic resources in the first pass
+ // If the var is a generic resource, its resourceInfos will be empty.
+ if (varLayout->resourceInfos.Count() == 0)
+ continue;
+ SLANG_ASSERT(varLayout->resourceInfos.Count() == varLayout->typeLayout->resourceInfos.Count());
+ auto uniformInfo = varLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ auto tUniformInfo = varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ if (uniformInfo)
+ {
+ anyUniforms = true;
+ SLANG_ASSERT(tUniformInfo);
+ structLayoutInfo.size = Math::Max(structLayoutInfo.size, uniformInfo->index + tUniformInfo->count);
+ }
+ for (UInt i = 0; i < varLayout->resourceInfos.Count(); i++)
+ {
+ auto resInfo = varLayout->resourceInfos[i];
+ auto tresInfo = varLayout->typeLayout->resourceInfos[i];
+ SLANG_ASSERT(resInfo.kind == tresInfo.kind);
+ auto usedRangeSet = findUsedRangeSetForSpace(&context, resInfo.space);
+ markSpaceUsed(&context, resInfo.space);
+ usedRangeSet->usedResourceRanges[(int)resInfo.kind].Add(
+ nullptr, // we don't need to track parameter info here
+ resInfo.index,
+ resInfo.index + varLayout->typeLayout->resourceInfos[0].count);
+ }
+ structLayout->fields.Add(varLayout);
+ varLayoutMapping[varLayout] = varLayout;
+ }
+ auto originalGlobalCBufferInfo = programLayout->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer);
+ VarLayout::ResourceInfo globalCBufferInfo;
+ globalCBufferInfo.kind = LayoutResourceKind::None;
+ globalCBufferInfo.space = 0;
+ globalCBufferInfo.index = 0;
+ if (originalGlobalCBufferInfo)
+ {
+ globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer;
+ globalCBufferInfo.space = originalGlobalCBufferInfo->space;
+ globalCBufferInfo.index = originalGlobalCBufferInfo->index;
+ }
+ // we have the context restored, can continue to layout the generic variables now
+ for (auto & varLayout : globalStructLayout->fields)
+ {
+ if (varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::GenericResource))
+ {
+ RefPtr<Type> newType = varLayout->typeLayout->type->Substitute(typeSubst).As<Type>();
+ RefPtr<TypeLayout> newTypeLayout = CreateTypeLayout(
+ layoutContext.with(constantBufferRules),
+ newType);
+ auto layoutInfo = newTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ size_t uniformSize = layoutInfo ? layoutInfo->count : 0;
+ if (uniformSize)
+ {
+ if (globalCBufferInfo.kind == LayoutResourceKind::None)
+ {
+ // user defined a uniform via a global generic type argument
+ // but we have not reserved a binding for the global uniform buffer
+ UInt space = 0;
+ auto usedRangeSet = findUsedRangeSetForSpace(&context, space);
+ globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer;
+ globalCBufferInfo.index =
+ usedRangeSet->usedResourceRanges[
+ (int)LayoutResourceKind::ConstantBuffer].Allocate(nullptr, 1);
+ globalCBufferInfo.space = space;
+ }
+ }
+ RefPtr<VarLayout> newVarLayout = new VarLayout();
+ RefPtr<ParameterInfo> paramInfo = new ParameterInfo();
+ newVarLayout->varDecl = varLayout->varDecl;
+ newVarLayout->typeLayout = newTypeLayout;
+ paramInfo->varLayouts.Add(newVarLayout);
+ completeBindingsForParameter(&context, paramInfo);
+ // update uniform layout
+
+ if (uniformSize != 0)
+ {
+ // Make sure uniform fields get laid out properly...
+ UniformLayoutInfo fieldInfo(
+ uniformSize,
+ newTypeLayout->uniformAlignment);
+ size_t uniformOffset = layoutContext.getRulesFamily()->getConstantBufferRules()->AddStructField(
+ &structLayoutInfo,
+ fieldInfo);
+ newVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset;
+ anyUniforms = true;
+ }
+ structLayout->fields.Add(newVarLayout);
+ varLayoutMapping[varLayout] = newVarLayout;
+ }
+ }
+ for (auto mapping : globalStructLayout->mapVarToLayout)
+ {
+ RefPtr<VarLayout> updatedVarLayout = mapping.Value;
+ varLayoutMapping.TryGetValue(updatedVarLayout, updatedVarLayout);
+ structLayout->mapVarToLayout[mapping.Key] = updatedVarLayout;
+ }
+
+ // If there are global-scope uniforms, then we need to wrap
+ // up a global constant buffer type layout to hold them
+ RefPtr<VarLayout> globalVarLayout = new VarLayout();
+ if (anyUniforms)
+ {
+ auto globalConstantBufferLayout = createParameterGroupTypeLayout(
+ layoutContext,
+ nullptr,
+ constantBufferRules,
+ constantBufferRules->GetObjectLayout(ShaderParameterKind::ConstantBuffer),
+ structLayout);
+
+ globalScopeLayout = globalConstantBufferLayout;
+ auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer);
+ *cbInfo = globalCBufferInfo;
+ }
+ globalVarLayout->typeLayout = globalScopeLayout;
+ programLayout->globalScopeLayout = globalVarLayout;
+ newProgramLayout->globalScopeLayout = globalVarLayout;
+ return newProgramLayout;
+}
}
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index 42c763099..0a4360e3f 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -95,9 +95,9 @@ namespace Slang
RefPtr<StructDecl> ParseStruct();
RefPtr<ClassDecl> ParseClass();
RefPtr<Stmt> ParseStatement();
- RefPtr<Stmt> ParseBlockStatement();
- RefPtr<DeclStmt> ParseVarDeclrStatement(Modifiers modifiers);
- RefPtr<IfStmt> ParseIfStatement();
+ RefPtr<Stmt> parseBlockStatement();
+ RefPtr<DeclStmt> parseVarDeclrStatement(Modifiers modifiers);
+ RefPtr<IfStmt> parseIfStatement();
RefPtr<ForStmt> ParseForStatement();
RefPtr<WhileStmt> ParseWhileStatement();
RefPtr<DoWhileStmt> ParseDoWhileStatement();
@@ -1034,7 +1034,7 @@ namespace Slang
}
else
{
- decl->Body = parser->ParseBlockStatement();
+ decl->Body = parser->parseBlockStatement();
}
parser->PopScope();
@@ -2172,41 +2172,55 @@ namespace Slang
}
}
- RefPtr<RefObject> ParseAssocType(Parser * parser, void *)
+ void parseOptionalGenericConstraints(Parser * parser, ContainerDecl* decl)
{
- RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl();
-
- auto nameToken = parser->ReadToken(TokenType::Identifier);
- assocTypeDecl->nameAndLoc = NameLoc(nameToken);
- assocTypeDecl->loc = nameToken.loc;
if (AdvanceIf(parser, TokenType::Colon))
{
- while (!parser->tokenReader.IsAtEnd())
+ do
{
- auto paramConstraint = new GenericTypeConstraintDecl();
+ RefPtr<GenericTypeConstraintDecl> paramConstraint = new GenericTypeConstraintDecl();
parser->FillPosition(paramConstraint);
- auto paramType = DeclRefType::Create(
+ RefPtr<DeclRefType> paramType = DeclRefType::Create(
parser->getSession(),
- DeclRef<Decl>(assocTypeDecl, nullptr));
+ DeclRef<Decl>(decl, nullptr));
- auto paramTypeExpr = new SharedTypeExpr();
- paramTypeExpr->loc = assocTypeDecl->loc;
+ RefPtr<SharedTypeExpr> paramTypeExpr = new SharedTypeExpr();
+ paramTypeExpr->loc = decl->loc;
paramTypeExpr->base.type = paramType;
paramTypeExpr->type = QualType(getTypeType(paramType));
paramConstraint->sub = TypeExp(paramTypeExpr);
paramConstraint->sup = parser->ParseTypeExp();
- AddMember(assocTypeDecl, paramConstraint);
- if (!AdvanceIf(parser, TokenType::Comma))
- break;
- }
+ AddMember(decl, paramConstraint);
+ } while (AdvanceIf(parser, TokenType::Comma));
}
+ }
+
+ RefPtr<RefObject> parseAssocType(Parser * parser, void *)
+ {
+ RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl();
+
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+ assocTypeDecl->nameAndLoc = NameLoc(nameToken);
+ assocTypeDecl->loc = nameToken.loc;
+ parseOptionalGenericConstraints(parser, assocTypeDecl);
parser->ReadToken(TokenType::Semicolon);
return assocTypeDecl;
}
+ RefPtr<RefObject> parseGlobalGenericParamDecl(Parser * parser, void *)
+ {
+ RefPtr<GlobalGenericParamDecl> genParamDecl = new GlobalGenericParamDecl();
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+ genParamDecl->nameAndLoc = NameLoc(nameToken);
+ genParamDecl->loc = nameToken.loc;
+ parseOptionalGenericConstraints(parser, genParamDecl);
+ parser->ReadToken(TokenType::Semicolon);
+ return genParamDecl;
+ }
+
static RefPtr<RefObject> parseInterfaceDecl(Parser* parser, void* /*userData*/)
{
RefPtr<InterfaceDecl> decl = new InterfaceDecl();
@@ -2220,7 +2234,7 @@ namespace Slang
return decl;
}
- static RefPtr<RefObject> ParseConstructorDecl(Parser* parser, void* /*userData*/)
+ static RefPtr<RefObject> parseConstructorDecl(Parser* parser, void* /*userData*/)
{
RefPtr<ConstructorDecl> decl = new ConstructorDecl();
parser->FillPosition(decl.Ptr());
@@ -2243,7 +2257,7 @@ namespace Slang
}
else
{
- decl->Body = parser->ParseBlockStatement();
+ decl->Body = parser->parseBlockStatement();
}
return decl;
}
@@ -2271,7 +2285,7 @@ namespace Slang
if( parser->tokenReader.PeekTokenType() == TokenType::LBrace )
{
- decl->Body = parser->ParseBlockStatement();
+ decl->Body = parser->parseBlockStatement();
}
else
{
@@ -2664,7 +2678,7 @@ namespace Slang
parser->ReadToken(TokenType::LParent);
stmt->condition = parser->ParseExpression();
parser->ReadToken(TokenType::RParent);
- stmt->body = parser->ParseBlockStatement();
+ stmt->body = parser->parseBlockStatement();
return stmt;
}
@@ -2788,11 +2802,11 @@ namespace Slang
RefPtr<Stmt> statement;
if (LookAheadToken(TokenType::LBrace))
- statement = ParseBlockStatement();
+ statement = parseBlockStatement();
else if (peekTypeName(this))
- statement = ParseVarDeclrStatement(modifiers);
+ statement = parseVarDeclrStatement(modifiers);
else if (LookAheadToken("if"))
- statement = ParseIfStatement();
+ statement = parseIfStatement();
else if (LookAheadToken("for"))
statement = ParseForStatement();
else if (LookAheadToken("while"))
@@ -2852,7 +2866,7 @@ namespace Slang
// Note: the declaration will consume any modifiers
// that had been in place on the statement.
tokenReader.mCursor = startPos;
- statement = ParseVarDeclrStatement(modifiers);
+ statement = parseVarDeclrStatement(modifiers);
return statement;
}
@@ -2885,7 +2899,7 @@ namespace Slang
return statement;
}
- RefPtr<Stmt> Parser::ParseBlockStatement()
+ RefPtr<Stmt> Parser::parseBlockStatement()
{
// If we are being asked not to check things *and* we haven't
// seen any `import` declarations yet, then we can safely assume
@@ -2983,7 +2997,7 @@ namespace Slang
return blockStatement;
}
- RefPtr<DeclStmt> Parser::ParseVarDeclrStatement(
+ RefPtr<DeclStmt> Parser::parseVarDeclrStatement(
Modifiers modifiers)
{
RefPtr<DeclStmt>varDeclrStatement = new DeclStmt();
@@ -2994,7 +3008,7 @@ namespace Slang
return varDeclrStatement;
}
- RefPtr<IfStmt> Parser::ParseIfStatement()
+ RefPtr<IfStmt> Parser::parseIfStatement()
{
RefPtr<IfStmt> ifStatement = new IfStmt();
FillPosition(ifStatement.Ptr());
@@ -3045,7 +3059,7 @@ namespace Slang
ReadToken(TokenType::LParent);
if (peekTypeName(this))
{
- stmt->InitialStatement = ParseVarDeclrStatement(Modifiers());
+ stmt->InitialStatement = parseVarDeclrStatement(Modifiers());
}
else
{
@@ -3107,7 +3121,7 @@ namespace Slang
return breakStatement;
}
- RefPtr<ContinueStmt> Parser::ParseContinueStatement()
+ RefPtr<ContinueStmt> Parser::ParseContinueStatement()
{
RefPtr<ContinueStmt> continueStatement = new ContinueStmt();
FillPosition(continueStatement.Ptr());
@@ -4128,17 +4142,18 @@ namespace Slang
// Add syntax for declaration keywords
#define DECL(KEYWORD, CALLBACK) \
addBuiltinSyntax<Decl>(session, scope, #KEYWORD, &CALLBACK)
- DECL(typedef, ParseTypeDef);
- DECL(associatedtype,ParseAssocType);
- DECL(cbuffer, parseHLSLCBufferDecl);
- DECL(tbuffer, parseHLSLTBufferDecl);
- DECL(__generic, ParseGenericDecl);
- DECL(__extension, ParseExtensionDecl);
- DECL(__init, ParseConstructorDecl);
- DECL(__subscript, ParseSubscriptDecl);
- DECL(interface, parseInterfaceDecl);
- DECL(syntax, parseSyntaxDecl);
- DECL(__import, parseImportDecl);
+ DECL(typedef, ParseTypeDef);
+ DECL(associatedtype, parseAssocType);
+ DECL(__generic_param, parseGlobalGenericParamDecl);
+ DECL(cbuffer, parseHLSLCBufferDecl);
+ DECL(tbuffer, parseHLSLTBufferDecl);
+ DECL(__generic, ParseGenericDecl);
+ DECL(__extension, ParseExtensionDecl);
+ DECL(__init, parseConstructorDecl);
+ DECL(__subscript, ParseSubscriptDecl);
+ DECL(interface, parseInterfaceDecl);
+ DECL(syntax, parseSyntaxDecl);
+ DECL(__import, parseImportDecl);
#undef DECL
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index 9fc032c76..14199f126 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -886,3 +886,22 @@ SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangRefl
return convert(program->entryPoints[(int) index].Ptr());
}
+
+SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram)
+{
+ auto program = convert(inProgram);
+ if (!program) return 0;
+ auto cb = program->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer);
+ if (!cb) return 0;
+ return cb->index;
+}
+
+SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* inProgram)
+{
+ auto program = convert(inProgram);
+ if (!program) return 0;
+ auto structLayout = getGlobalStructLayout(program);
+ auto uniform = structLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ if (!uniform) return 0;
+ return uniform->count;
+}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 17f8ea96d..6a103fc2d 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -400,14 +400,16 @@ void CompileRequest::addTranslationUnitSourceFile(
int CompileRequest::addEntryPoint(
int translationUnitIndex,
String const& name,
- Profile entryPointProfile)
+ Profile entryPointProfile,
+ List<String> const & genericTypeNames)
{
RefPtr<EntryPointRequest> entryPoint = new EntryPointRequest();
entryPoint->compileRequest = this;
entryPoint->name = getNamePool()->getName(name);
entryPoint->profile = entryPointProfile;
entryPoint->translationUnitIndex = translationUnitIndex;
-
+ for (auto typeName : genericTypeNames)
+ entryPoint->genericParameterTypeNames.Add(getNamePool()->getName(typeName));
auto translationUnit = translationUnits[translationUnitIndex].Ptr();
translationUnit->entryPoints.Add(entryPoint);
@@ -909,7 +911,31 @@ SLANG_API int spAddEntryPoint(
return req->addEntryPoint(
translationUnitIndex,
name,
- Slang::Profile(Slang::Profile::RawVal(profile)));
+ Slang::Profile(Slang::Profile::RawVal(profile)),
+ Slang::List<Slang::String>());
+}
+
+SLANG_API int spAddEntryPointEx(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ char const* name,
+ SlangProfileID profile,
+ int genericParamTypeNameCount,
+ char const ** genericParamTypeNames)
+{
+ if (!request) return -1;
+ auto req = REQ(request);
+ if (!name) return -1;
+ if (translationUnitIndex < 0) return -1;
+ if (Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1;
+ Slang::List<Slang::String> typeNames;
+ for (int i = 0; i < genericParamTypeNameCount; i++)
+ typeNames.Add(genericParamTypeNames[i]);
+ return req->addEntryPoint(
+ translationUnitIndex,
+ name,
+ Slang::Profile(Slang::Profile::RawVal(profile)),
+ typeNames);
}
diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h
index 3c7e8c5ae..fdb2694a9 100644
--- a/source/slang/syntax-base-defs.h
+++ b/source/slang/syntax-base-defs.h
@@ -197,6 +197,38 @@ SYNTAX_CLASS(ThisTypeSubstitution, Substitutions)
)
END_SYNTAX_CLASS()
+SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions)
+ // the __generic_param decl to be substituted
+ DECL_FIELD(GlobalGenericParamDecl*, paramDecl)
+ // the actual type to substitute in
+ SYNTAX_FIELD(RefPtr<Val>, actualType)
+
+RAW(
+ // Apply a set of substitutions to the bindings in this substitution
+ virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
+
+ // Check if these are equivalent substitutiosn to another set
+ virtual bool Equals(Substitutions* subst) override;
+ virtual bool operator == (const Substitutions & subst) override
+ {
+ return Equals(const_cast<Substitutions*>(&subst));
+ }
+ virtual int GetHashCode() const override
+ {
+ int rs = actualType->GetHashCode();
+ for (auto && v : witnessTables)
+ {
+ rs = combineHash(rs, v.Key->GetHashCode());
+ rs = combineHash(rs, v.Value->GetHashCode());
+ }
+ return rs;
+ }
+ typedef List<KeyValuePair<RefPtr<Type>, RefPtr<Val>>> WitnessTableLookupTable;
+)
+ // The witness tables for each interface this actual type implements
+ SYNTAX_FIELD(WitnessTableLookupTable, witnessTables)
+END_SYNTAX_CLASS()
+
ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase)
END_SYNTAX_CLASS()
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index e5fc8dfa3..fa9c88051 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -93,6 +93,7 @@ ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode);
ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode);
ABSTRACT_SYNTAX_CLASS(GenericSubstitution, Substitutions);
ABSTRACT_SYNTAX_CLASS(ThisTypeSubstitution, Substitutions);
+ABSTRACT_SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions);
#include "expr-defs.h"
#include "decl-defs.h"
@@ -488,6 +489,20 @@ void Type::accept(IValVisitor* visitor, void* extra)
}
}
}
+ else if (auto globalGenParam = dynamic_cast<GlobalGenericParamDecl*>(declRef.getDecl()))
+ {
+ // search for a substitution that might apply to us
+ for (auto s = subst; s; s = s->outer.Ptr())
+ {
+ if (auto genericSubst = dynamic_cast<GlobalGenericParamSubstitution*>(s))
+ {
+ if (genericSubst->paramDecl == globalGenParam)
+ {
+ return genericSubst->actualType;
+ }
+ }
+ }
+ }
int diff = 0;
DeclRef<Decl> substDeclRef = declRef.SubstituteImpl(subst, &diff);
@@ -1208,6 +1223,35 @@ void Type::accept(IValVisitor* visitor, void* extra)
return false;
}
+ RefPtr<Substitutions> GlobalGenericParamSubstitution::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/)
+ {
+ // we will never replace values for this type of substitution
+ return this;
+ }
+
+ bool GlobalGenericParamSubstitution::Equals(Substitutions* subst)
+ {
+ if (!subst)
+ return false;
+ if (auto genSubst = dynamic_cast<GlobalGenericParamSubstitution*>(subst))
+ {
+ if (paramDecl != genSubst->paramDecl)
+ return false;
+ if (!actualType->EqualsVal(genSubst->actualType))
+ return false;
+ if (witnessTables.Count() != genSubst->witnessTables.Count())
+ return false;
+ for (UInt i = 0; i < witnessTables.Count(); i++)
+ {
+ if (!witnessTables[i].Key->Equals(genSubst->witnessTables[i].Key))
+ return false;
+ if (!witnessTables[i].Value->EqualsVal(genSubst->witnessTables[i].Value))
+ return false;
+ }
+ return true;
+ }
+ return false;
+ }
// DeclRefBase
@@ -1564,6 +1608,24 @@ void Type::accept(IValVisitor* visitor, void* extra)
return genericSubst->args[index + ordinaryParamCount];
}
}
+ else if (auto globalGenParamSubst = dynamic_cast<GlobalGenericParamSubstitution*>(s))
+ {
+ // we have a GlobalGenericParamSubstitution, this substitution will provide
+ // a concrete IRWitnessTable for a generic global variable
+ auto supType = GetSup(genConstraintDecl);
+
+ // check if the substitution is really about this global generic type parameter
+ if (globalGenParamSubst->paramDecl != genConstraintDecl.getDecl()->ParentDecl)
+ continue;
+
+ // find witness table for the required interface
+ for (auto witness : globalGenParamSubst->witnessTables)
+ if (witness.Key->EqualsVal(supType))
+ {
+ (*ioDiff)++;
+ return witness.Value;
+ }
+ }
}
}
RefPtr<DeclaredSubtypeWitness> rs = new DeclaredSubtypeWitness();
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 46beca2d9..b4d550ef5 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -1073,7 +1073,7 @@ namespace Slang
{
return declRef.Substitute(declRef.getDecl()->base.type);
}
-
+
inline RefPtr<Type> GetType(DeclRef<TypeDefDecl> const& declRef)
{
return declRef.Substitute(declRef.getDecl()->type.Ptr());
diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp
index 8fa790dd8..30b2ee01a 100644
--- a/source/slang/type-layout.cpp
+++ b/source/slang/type-layout.cpp
@@ -1222,6 +1222,11 @@ SimpleLayoutInfo GetLayoutImpl(
return GetLayoutImpl(subContext, type, outTypeLayout, SimpleLayoutInfo());
}
+int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl)
+{
+ return (int)genericParameters.FindFirst([=](RefPtr<GenericParamLayout> & x) {return x->decl.Ptr() == decl; });
+}
+
SimpleLayoutInfo GetLayoutImpl(
TypeLayoutContext const& context,
Type* type,
@@ -1599,6 +1604,25 @@ SimpleLayoutInfo GetLayoutImpl(
return info;
}
+ else if (auto globalGenParam = declRef.As<GlobalGenericParamDecl>())
+ {
+ SimpleLayoutInfo info;
+ info.alignment = 0;
+ info.size = 0;
+ info.kind = LayoutResourceKind::GenericResource;
+ if (outTypeLayout)
+ {
+ auto genParamTypeLayout = new GenericParamTypeLayout();
+ // we should have already populated ProgramLayout::genericEntryPointParams list at this point,
+ // so we can find the index of this generic param decl in the list
+ genParamTypeLayout->type = type;
+ genParamTypeLayout->paramIndex = findGenericParam(context.targetReq->layout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl());
+ genParamTypeLayout->rules = rules;
+ genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count++;
+ *outTypeLayout = genParamTypeLayout;
+ }
+ return info;
+ }
}
else if (auto errorType = type->As<ErrorType>())
{
@@ -1667,4 +1691,12 @@ RefPtr<TypeLayout> CreateTypeLayout(
return CreateTypeLayout(context, type, SimpleLayoutInfo());
}
+RefPtr<GlobalGenericParamDecl> GenericParamTypeLayout::getGlobalGenericParamDecl()
+{
+ auto declRefType = type->AsDeclRefType();
+ SLANG_ASSERT(declRefType);
+ auto rsDeclRef = declRefType->declRef.As<GlobalGenericParamDecl>();
+ return rsDeclRef.getDecl();
+}
+
} // namespace Slang
diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h
index 363b01486..4ce6dc355 100644
--- a/source/slang/type-layout.h
+++ b/source/slang/type-layout.h
@@ -220,7 +220,7 @@ typedef unsigned int VarLayoutFlags;
enum VarLayoutFlag : VarLayoutFlags
{
IsRedeclaration = 1 << 0, ///< This is a redeclaration of some shader parameter
- HasSemantic = 1 << 1,
+ HasSemantic = 1 << 1
};
// A reified layout for a particular variable, field, etc.
@@ -358,6 +358,13 @@ public:
Dictionary<Decl*, RefPtr<VarLayout>> mapVarToLayout;
};
+class GenericParamTypeLayout : public TypeLayout
+{
+public:
+ RefPtr<GlobalGenericParamDecl> getGlobalGenericParamDecl();
+ int paramIndex = 0;
+};
+
// Layout information for a single shader entry point
// within a program
//
@@ -386,6 +393,13 @@ public:
unsigned flags = 0;
};
+class GenericParamLayout : public Layout
+{
+public:
+ RefPtr<GlobalGenericParamDecl> decl;
+ int index;
+};
+
// Layout information for the global scope of a program
class ProgramLayout : public Layout
{
@@ -403,13 +417,15 @@ public:
// (since a constant buffer will have to be allocated
// to store them).
//
- RefPtr<TypeLayout> globalScopeLayout;
+ RefPtr<VarLayout> globalScopeLayout;
// We catalog the requested entry points here,
// and any entry-point-specific parameter data
// will (eventually) belong there...
List<RefPtr<EntryPointLayout>> entryPoints;
+ List<RefPtr<GenericParamLayout>> globalGenericParams;
+
// HACK: binding to use when we have to create
// a dummy sampler just to appease glslang
int bindingForHackSampler = 0;
diff --git a/tests/compute/global-type-param.slang b/tests/compute/global-type-param.slang
new file mode 100644
index 000000000..301ef1021
--- /dev/null
+++ b/tests/compute/global-type-param.slang
@@ -0,0 +1,30 @@
+//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT:type Impl
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IBase
+{
+ float compute();
+}
+
+struct Impl : IBase
+{
+ float compute()
+ {
+ return 1.0;
+ }
+};
+
+__generic_param TImpl : IBase;
+
+TImpl impl;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ float outVal = impl.compute();
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/global-type-param.slang.expected.txt b/tests/compute/global-type-param.slang.expected.txt
new file mode 100644
index 000000000..47b9ba0c8
--- /dev/null
+++ b/tests/compute/global-type-param.slang.expected.txt
@@ -0,0 +1 @@
+3F800000 \ No newline at end of file
diff --git a/tests/compute/global-type-param1.slang b/tests/compute/global-type-param1.slang
new file mode 100644
index 000000000..c9b754aa3
--- /dev/null
+++ b/tests/compute/global-type-param1.slang
@@ -0,0 +1,46 @@
+//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT: cbuffer(data=[0.5 0 0 0 1.0], stride=4):dxbinding(0),glbinding(0)
+//TEST_INPUT: cbuffer(data=[1.0], stride=4):dxbinding(1),glbinding(1)
+//TEST_INPUT: Texture2D(size=4, content = zero) : dxbinding(0),glbinding(0)
+//TEST_INPUT: Texture2D(size=4, content = one) : dxbinding(1),glbinding(1)
+//TEST_INPUT: Sampler : dxbinding(0),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: Sampler : dxbinding(1),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT: type Impl
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IBase
+{
+ float compute();
+}
+
+struct Impl : IBase
+{
+ float base; // = 1.0
+ Texture2D tex;
+ SamplerState sampler;
+ float compute()
+ {
+ return tex.SampleLevel(sampler, float2(0.0), 0.0).x + base;
+ }
+};
+
+__generic_param TImpl : IBase;
+
+TImpl impl;
+
+float base0; // = 0.5
+
+Texture2D tex1; // = 0.0
+SamplerState sampler;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5
+ float outVal = impl.compute(); // = 2.0
+ outVal = b0 / outVal; // = 0.25
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/global-type-param1.slang.expected.txt b/tests/compute/global-type-param1.slang.expected.txt
new file mode 100644
index 000000000..4846e7be2
--- /dev/null
+++ b/tests/compute/global-type-param1.slang.expected.txt
@@ -0,0 +1 @@
+3E800000
diff --git a/tests/compute/global-type-param2.slang b/tests/compute/global-type-param2.slang
new file mode 100644
index 000000000..b54f4c430
--- /dev/null
+++ b/tests/compute/global-type-param2.slang
@@ -0,0 +1,61 @@
+//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT: cbuffer(data=[0.5 0 0 0], stride=4):dxbinding(0),glbinding(0)
+//TEST_INPUT: cbuffer(data=[1.0], stride=4):dxbinding(1),glbinding(1)
+//TEST_INPUT: Texture2D(size=4, content = zero) : dxbinding(0),glbinding(0)
+//TEST_INPUT: Texture2D(size=4, content = one) : dxbinding(1),glbinding(1)
+//TEST_INPUT: Sampler : dxbinding(0),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: Sampler : dxbinding(1),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT: type Impl
+
+
+/* Testing this scenario:
+The ProgramLayout before specializeIRForEntryPoint() has no global cbuffer
+allocated because the program before specialization does not define any
+global uniform variables.
+
+However, after specialization, we find ourselves needing a global constant
+buffer. The compiler should allocate the next available constant buffer slot
+(here c1 because c0 is already used by the explicit cbuffer decl `existingBuffer`)
+for the newly generated global constant buffer.
+*/
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IBase
+{
+ float compute();
+}
+
+struct Impl : IBase
+{
+ float base; // = 1.0
+ Texture2D tex;
+ SamplerState sampler;
+ float compute()
+ {
+ return tex.SampleLevel(sampler, float2(0.0), 0.0).x + base;
+ }
+};
+
+__generic_param TImpl : IBase;
+
+TImpl impl;
+
+// at binding c0:
+cbuffer existingBuffer
+{
+ float base0; // = 0.5
+}
+Texture2D tex1; // = 0.0
+SamplerState sampler;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5
+ float outVal = impl.compute(); // = 2.0
+ outVal = b0 / outVal; // = 0.25
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/global-type-param2.slang.expected.txt b/tests/compute/global-type-param2.slang.expected.txt
new file mode 100644
index 000000000..4846e7be2
--- /dev/null
+++ b/tests/compute/global-type-param2.slang.expected.txt
@@ -0,0 +1 @@
+3E800000
diff --git a/tools/render-test/main.cpp b/tools/render-test/main.cpp
index cb0eb927d..51a96436f 100644
--- a/tools/render-test/main.cpp
+++ b/tools/render-test/main.cpp
@@ -117,6 +117,7 @@ Error initializeShaders(
compileRequest.computeShader.name = computeEntryPointName;
compileRequest.computeShader.profile = computeProfileName;
}
+ compileRequest.entryPointTypeArguments = gShaderInputLayout.globalTypeArguments;
gShaderProgram = shaderCompiler->compileProgram(compileRequest);
if( !gShaderProgram )
{
diff --git a/tools/render-test/render-d3d11.cpp b/tools/render-test/render-d3d11.cpp
index 9bac24094..cdd6c778e 100644
--- a/tools/render-test/render-d3d11.cpp
+++ b/tools/render-test/render-d3d11.cpp
@@ -455,10 +455,18 @@ public:
ID3D11Buffer * buffer = nullptr;
};
+ UInt RoundUpToAlignment(UInt size, UInt alignment)
+ {
+ if (size % alignment)
+ return (size / alignment + 1) * alignment;
+ else
+ return Math::Max(size, alignment);
+ }
+
virtual Buffer* createBuffer(BufferDesc const& desc) override
{
D3D11_BUFFER_DESC dxBufferDesc = { 0 };
- dxBufferDesc.ByteWidth = (UINT) desc.size;
+ dxBufferDesc.ByteWidth = (UINT)RoundUpToAlignment(desc.size, 256);
switch( desc.flavor )
{
@@ -773,7 +781,11 @@ public:
{
auto dxContext = dxImmediateContext;
D3D11_BUFFER_DESC desc = {0};
- desc.ByteWidth = (UINT)(bufferData.Count() * sizeof(unsigned int));
+ List<unsigned int> newBuffer;
+ desc.ByteWidth = (UINT)RoundUpToAlignment((bufferData.Count() * sizeof(unsigned int)), 256);
+ newBuffer.SetSize(desc.ByteWidth / sizeof(unsigned int));
+ for (UInt i = 0; i < bufferData.Count(); i++)
+ newBuffer[i] = bufferData[i];
if (bufferDesc.type == InputBufferType::ConstantBuffer)
{
desc.Usage = D3D11_USAGE_DEFAULT;
@@ -794,7 +806,7 @@ public:
}
}
D3D11_SUBRESOURCE_DATA data = {0};
- data.pSysMem = bufferData.Buffer();
+ data.pSysMem = newBuffer.Buffer();
dxDevice->CreateBuffer(&desc, &data, &bufferOut);
int elemSize = bufferDesc.stride <= 0 ? 1 : bufferDesc.stride;
if (bufferDesc.type == InputBufferType::StorageBuffer)
@@ -1091,7 +1103,7 @@ public:
D3D11_BUFFER_DESC bufDesc;
memset(&bufDesc, 0, sizeof(bufDesc));
bufDesc.BindFlags = 0;
- bufDesc.ByteWidth = binding.bufferLength;
+ bufDesc.ByteWidth = (UINT)RoundUpToAlignment(binding.bufferLength, 256);
bufDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
bufDesc.Usage = D3D11_USAGE_STAGING;
dxDevice->CreateBuffer(&bufDesc, nullptr, &stageBuf);
diff --git a/tools/render-test/render.h b/tools/render-test/render.h
index dec48cda4..174ba0b7b 100644
--- a/tools/render-test/render.h
+++ b/tools/render-test/render.h
@@ -31,6 +31,7 @@ struct ShaderCompileRequest
EntryPoint vertexShader;
EntryPoint fragmentShader;
EntryPoint computeShader;
+ Slang::List<Slang::String> entryPointTypeArguments;
};
class ShaderCompiler
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index ef78fe3d5..01328eabd 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -7,6 +7,7 @@ namespace renderer_test
void ShaderInputLayout::Parse(const char * source)
{
entries.Clear();
+ globalTypeArguments.Clear();
auto lines = Split(source, '\n');
for (auto & line : lines)
{
@@ -16,200 +17,208 @@ namespace renderer_test
TokenReader parser(lineContent);
try
{
- ShaderInputLayoutEntry entry;
-
- if (parser.LookAhead("cbuffer"))
- {
- entry.type = ShaderInputType::Buffer;
- entry.bufferDesc.type = InputBufferType::ConstantBuffer;
- }
- else if (parser.LookAhead("ubuffer"))
- {
- entry.type = ShaderInputType::Buffer;
- entry.bufferDesc.type = InputBufferType::StorageBuffer;
- }
- else if (parser.LookAhead("Texture1D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 1;
- }
- else if (parser.LookAhead("Texture2D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- }
- else if (parser.LookAhead("Texture3D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 3;
- }
- else if (parser.LookAhead("TextureCube"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isCube = true;
- }
- else if (parser.LookAhead("RWTexture1D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 1;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("RWTexture2D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("RWTexture3D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 3;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("RWTextureCube"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isCube = true;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("Sampler"))
- {
- entry.type = ShaderInputType::Sampler;
- }
- else if (parser.LookAhead("Sampler1D"))
- {
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 1;
- }
- else if (parser.LookAhead("Sampler2D"))
- {
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 2;
- }
- else if (parser.LookAhead("Sampler3D"))
- {
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 3;
- }
- else if (parser.LookAhead("SamplerCube"))
+ if (parser.LookAhead("type"))
{
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isCube = true;
+ parser.ReadToken();
+ globalTypeArguments.Add(parser.ReadWord());
}
- else if (parser.LookAhead("render_targets"))
+ else
{
- numRenderTargets = parser.ReadInt();
- continue;
- }
- parser.ReadToken();
- // parse options
- if (parser.LookAhead("("))
- {
- parser.Read("(");
- while (!parser.IsEnd() && !parser.LookAhead(")"))
+ ShaderInputLayoutEntry entry;
+
+ if (parser.LookAhead("cbuffer"))
{
- auto word = parser.ReadWord();
- if (word == "depth")
- {
- entry.textureDesc.isDepthTexture = true;
- }
- else if (word == "depthCompare")
- {
- entry.samplerDesc.isCompareSampler = true;
- }
- else if (word == "arrayLength")
- {
- parser.Read("=");
- entry.textureDesc.arrayLength = parser.ReadInt();
- }
- else if (word == "stride")
- {
- parser.Read("=");
- entry.bufferDesc.stride = parser.ReadInt();
- }
- else if (word == "size")
- {
- parser.Read("=");
- entry.textureDesc.size = parser.ReadInt();
- }
- else if (word == "data")
+ entry.type = ShaderInputType::Buffer;
+ entry.bufferDesc.type = InputBufferType::ConstantBuffer;
+ }
+ else if (parser.LookAhead("ubuffer"))
+ {
+ entry.type = ShaderInputType::Buffer;
+ entry.bufferDesc.type = InputBufferType::StorageBuffer;
+ }
+ else if (parser.LookAhead("Texture1D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 1;
+ }
+ else if (parser.LookAhead("Texture2D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ }
+ else if (parser.LookAhead("Texture3D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 3;
+ }
+ else if (parser.LookAhead("TextureCube"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isCube = true;
+ }
+ else if (parser.LookAhead("RWTexture1D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 1;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("RWTexture2D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("RWTexture3D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 3;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("RWTextureCube"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isCube = true;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("Sampler"))
+ {
+ entry.type = ShaderInputType::Sampler;
+ }
+ else if (parser.LookAhead("Sampler1D"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 1;
+ }
+ else if (parser.LookAhead("Sampler2D"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 2;
+ }
+ else if (parser.LookAhead("Sampler3D"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 3;
+ }
+ else if (parser.LookAhead("SamplerCube"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isCube = true;
+ }
+ else if (parser.LookAhead("render_targets"))
+ {
+ numRenderTargets = parser.ReadInt();
+ continue;
+ }
+ parser.ReadToken();
+ // parse options
+ if (parser.LookAhead("("))
+ {
+ parser.Read("(");
+ while (!parser.IsEnd() && !parser.LookAhead(")"))
{
- parser.Read("=");
- parser.Read("[");
- while (!parser.IsEnd() && !parser.LookAhead("]"))
+ auto word = parser.ReadWord();
+ if (word == "depth")
+ {
+ entry.textureDesc.isDepthTexture = true;
+ }
+ else if (word == "depthCompare")
+ {
+ entry.samplerDesc.isCompareSampler = true;
+ }
+ else if (word == "arrayLength")
{
- if (parser.NextToken().Type == TokenType::IntLiteral)
+ parser.Read("=");
+ entry.textureDesc.arrayLength = parser.ReadInt();
+ }
+ else if (word == "stride")
+ {
+ parser.Read("=");
+ entry.bufferDesc.stride = parser.ReadInt();
+ }
+ else if (word == "size")
+ {
+ parser.Read("=");
+ entry.textureDesc.size = parser.ReadInt();
+ }
+ else if (word == "data")
+ {
+ parser.Read("=");
+ parser.Read("[");
+ while (!parser.IsEnd() && !parser.LookAhead("]"))
{
- entry.bufferData.Add(parser.ReadUInt());
+ if (parser.NextToken().Type == TokenType::IntLiteral)
+ {
+ entry.bufferData.Add(parser.ReadUInt());
+ }
+ else
+ {
+ auto floatNum = parser.ReadFloat();
+ entry.bufferData.Add(*(unsigned int*)&floatNum);
+ }
}
+ parser.Read("]");
+ }
+ else if (word == "content")
+ {
+ parser.Read("=");
+ auto contentWord = parser.ReadWord();
+ if (contentWord == "zero")
+ entry.textureDesc.content = InputTextureContent::Zero;
+ else if (contentWord == "one")
+ entry.textureDesc.content = InputTextureContent::One;
+ else if (contentWord == "chessboard")
+ entry.textureDesc.content = InputTextureContent::ChessBoard;
else
- {
- auto floatNum = parser.ReadFloat();
- entry.bufferData.Add(*(unsigned int*)&floatNum);
- }
+ entry.textureDesc.content = InputTextureContent::Gradient;
}
- parser.Read("]");
- }
- else if (word == "content")
- {
- parser.Read("=");
- auto contentWord = parser.ReadWord();
- if (contentWord == "zero")
- entry.textureDesc.content = InputTextureContent::Zero;
- else if (contentWord == "one")
- entry.textureDesc.content = InputTextureContent::One;
- else if (contentWord == "chessboard")
- entry.textureDesc.content = InputTextureContent::ChessBoard;
+ if (parser.LookAhead(","))
+ parser.Read(",");
else
- entry.textureDesc.content = InputTextureContent::Gradient;
+ break;
}
- if (parser.LookAhead(","))
- parser.Read(",");
- else
- break;
+ parser.Read(")");
}
- parser.Read(")");
- }
- // parse bindings
- if (parser.LookAhead(":"))
- {
- parser.Read(":");
- while (!parser.IsEnd())
+ // parse bindings
+ if (parser.LookAhead(":"))
{
- if (parser.LookAhead("dxbinding"))
- {
- parser.ReadToken();
- parser.Read("(");
- entry.hlslBinding = parser.ReadInt();
- parser.Read(")");
- }
- else if (parser.LookAhead("glbinding"))
+ parser.Read(":");
+ while (!parser.IsEnd())
{
- parser.ReadToken();
- parser.Read("(");
- while (!parser.IsEnd() && !parser.LookAhead(")"))
+ if (parser.LookAhead("dxbinding"))
{
- entry.glslBinding.Add(parser.ReadInt());
- if (parser.LookAhead(","))
- parser.Read(",");
- else
- break;
+ parser.ReadToken();
+ parser.Read("(");
+ entry.hlslBinding = parser.ReadInt();
+ parser.Read(")");
}
- parser.Read(")");
- }
- else if (parser.LookAhead("out"))
- {
- parser.ReadToken();
- entry.isOutput = true;
+ else if (parser.LookAhead("glbinding"))
+ {
+ parser.ReadToken();
+ parser.Read("(");
+ while (!parser.IsEnd() && !parser.LookAhead(")"))
+ {
+ entry.glslBinding.Add(parser.ReadInt());
+ if (parser.LookAhead(","))
+ parser.Read(",");
+ else
+ break;
+ }
+ parser.Read(")");
+ }
+ else if (parser.LookAhead("out"))
+ {
+ parser.ReadToken();
+ entry.isOutput = true;
+ }
+ if (parser.LookAhead(","))
+ parser.Read(",");
}
- if (parser.LookAhead(","))
- parser.Read(",");
}
+ entries.Add(entry);
}
- entries.Add(entry);
}
catch (TextFormatException)
{
diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h
index 9602e4fe8..c4c3d9d8c 100644
--- a/tools/render-test/shader-input-layout.h
+++ b/tools/render-test/shader-input-layout.h
@@ -63,6 +63,7 @@ namespace renderer_test
{
public:
Slang::List<ShaderInputLayoutEntry> entries;
+ Slang::List<Slang::String> globalTypeArguments;
int numRenderTargets = 1;
void Parse(const char * source);
};
diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp
index 63e24126c..746967cb7 100644
--- a/tools/render-test/slang-support.cpp
+++ b/tools/render-test/slang-support.cpp
@@ -84,7 +84,14 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
ShaderProgram * result = nullptr;
if (request.computeShader.name)
{
- int computeEntryPoint = spAddEntryPoint(slangRequest, computeTranslationUnit, computeEntryPointName, spFindProfile(slangSession, request.computeShader.profile));
+ Slang::List<const char*> rawTypeNames;
+ for (auto typeName : request.entryPointTypeArguments)
+ rawTypeNames.Add(typeName.Buffer());
+ int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit,
+ computeEntryPointName,
+ spFindProfile(slangSession, request.computeShader.profile),
+ (int)rawTypeNames.Count(),
+ rawTypeNames.Buffer());
int compileErr = spCompile(slangRequest);
if (auto diagnostics = spGetDiagnosticOutput(slangRequest))
{