summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2017-11-17 21:26:21 -0500
committerTim Foley <tfoleyNV@users.noreply.github.com>2017-11-17 18:26:21 -0800
commit54bf54bd0dda378f8400860b25855558f39cb52b (patch)
tree955931f37df819f3c6e22bc981089f644c1141e1
parent0298a0427bbfe19700169c4e239a1b9e91baa410 (diff)
Add support for global generic parameters (#285)
* Add support for global generic parameters (In-progress work) This commit include: 1. Update Slang API to allow specification of generic type arguments in an `EntryPointRequest` 2. Add parsing of `__generic_param` construct, which becomes a GlobalGenericParamDecl, contains members of `GenericTypeConstraintDecl`. 3. Semantics checking will check whether the provided type arguments conform to the interfaces as defined by the generic parameter, and store SubtypeWitness values in the EntryPointRequest, which will be used by `specializeIRForEntryPoint` when generating final IR. 4. Add a new type of substitution - `GlobalGenericParamSubstitution` for subsittuting references to `__generic_param` decls or to its member `GenericTypeConsraintDecl` with the actual type argument or witness tables. 5. Update `IRSpecContext` to apply `GlobalGenericParamSubstitution` when specializing the IR for an EntryPointRequest. 6. Update `render-test` to take additional `type` inputs, which specifies the type arguments to substitute into the global `__generic_param` types. This commit does not include ProgramLayout specialization. * IR: pass through `[unroll]` attribute (#284) The initial lowering was adding an `IRLoopControlDecoration` to the instruction at the head of a loop, but this was getting dropped when the IR gets cloned for a particular entry point. The fix was simply to add a case for loop-control decorations to `cloneDecoration`. * fix warnings * IR: support `CompileTimeForStmt` (#286) This statement type is a bit of a hack, to support loops that *must* be unrolled. The AST-to-AST pass handles them by cloning the AST for the loop body N times, and it was easy enough to do the same thing for the IR: emit the instructions for the body N times. The only thing that requires a bit of care is that now we might see the same variable declarations multiple times, so we need to play it safe and overwrite existing entries in our map from declarations to their IR values. Of course a better answer long-term would be to do the actual unrolling in the IR. This is especially true because we might some day want to support compile-time/must-unroll loops in functions, where the loop counter comes in as a parameter (but must still be compile-time-constant at every call site). * Add support for global generic parameters (In-progress work) This commit include: 1. Update Slang API to allow specification of generic type arguments in an `EntryPointRequest` 2. Add parsing of `__generic_param` construct, which becomes a GlobalGenericParamDecl, contains members of `GenericTypeConstraintDecl`. 3. Semantics checking will check whether the provided type arguments conform to the interfaces as defined by the generic parameter, and store SubtypeWitness values in the EntryPointRequest, which will be used by `specializeIRForEntryPoint` when generating final IR. 4. Add a new type of substitution - `GlobalGenericParamSubstitution` for subsittuting references to `__generic_param` decls or to its member `GenericTypeConsraintDecl` with the actual type argument or witness tables. 5. Update `IRSpecContext` to apply `GlobalGenericParamSubstitution` when specializing the IR for an EntryPointRequest. 6. Update `render-test` to take additional `type` inputs, which specifies the type arguments to substitute into the global `__generic_param` types. progress on parameter binding * Add a more contrived test case for specializing parameter bindings * update render-test to align buffers to 256 bytes (to get rid of D3D complains on minimal buffer size). * adding one more test case for parameter binding specialization. * Cleanup according to @tfoleyNV 's suggestions. * fix a bug introduced in the cleanup
-rw-r--r--slang.h28
-rw-r--r--source/core/list.h8
-rw-r--r--source/slang/check.cpp102
-rw-r--r--source/slang/compiler.cpp6
-rw-r--r--source/slang/compiler.h12
-rw-r--r--source/slang/decl-defs.h5
-rw-r--r--source/slang/diagnostic-defs.h9
-rw-r--r--source/slang/emit.cpp11
-rw-r--r--source/slang/emit.h5
-rw-r--r--source/slang/ir-insts.h4
-rw-r--r--source/slang/ir.cpp119
-rw-r--r--source/slang/lookup.cpp7
-rw-r--r--source/slang/lower-to-ir.cpp4
-rw-r--r--source/slang/lower.cpp7
-rw-r--r--source/slang/parameter-binding.cpp239
-rw-r--r--source/slang/parser.cpp103
-rw-r--r--source/slang/reflection.cpp19
-rw-r--r--source/slang/slang.cpp32
-rw-r--r--source/slang/syntax-base-defs.h32
-rw-r--r--source/slang/syntax.cpp62
-rw-r--r--source/slang/syntax.h2
-rw-r--r--source/slang/type-layout.cpp32
-rw-r--r--source/slang/type-layout.h20
-rw-r--r--tests/compute/global-type-param.slang30
-rw-r--r--tests/compute/global-type-param.slang.expected.txt1
-rw-r--r--tests/compute/global-type-param1.slang46
-rw-r--r--tests/compute/global-type-param1.slang.expected.txt1
-rw-r--r--tests/compute/global-type-param2.slang61
-rw-r--r--tests/compute/global-type-param2.slang.expected.txt1
-rw-r--r--tools/render-test/main.cpp1
-rw-r--r--tools/render-test/render-d3d11.cpp20
-rw-r--r--tools/render-test/render.h1
-rw-r--r--tools/render-test/shader-input-layout.cpp357
-rw-r--r--tools/render-test/shader-input-layout.h1
-rw-r--r--tools/render-test/slang-support.cpp9
35 files changed, 1123 insertions, 274 deletions
diff --git a/slang.h b/slang.h
index c4c2f54b7..30808a242 100644
--- a/slang.h
+++ b/slang.h
@@ -377,6 +377,18 @@ extern "C"
char const* name,
SlangProfileID profile);
+ /** Add an entry point in a particular translation unit,
+ with additional arguments that specify the concrete
+ type names for global generic type parameters.
+ */
+ SLANG_API int spAddEntryPointEx(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ char const* name,
+ SlangProfileID profile,
+ int genericTypeNameCount,
+ char const** genericTypeNames);
+
/** Execute the compilation request.
Returns zero on success, non-zero on failure.
@@ -588,6 +600,9 @@ extern "C"
// HLSL register `space`, Vulkan GLSL `set`
SLANG_PARAMETER_CATEGORY_REGISTER_SPACE,
+ // A parameter whose type is to be specialized by a global generic type argument
+ SLANG_PARAMETER_CATEGORY_GENERIC,
+
//
SLANG_PARAMETER_CATEGORY_COUNT,
};
@@ -695,6 +710,8 @@ extern "C"
SLANG_API SlangUInt spReflection_getEntryPointCount(SlangReflection* reflection);
SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangReflection* reflection, SlangUInt index);
+ SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* reflection);
+ SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* reflection);
#ifdef __cplusplus
}
@@ -848,6 +865,7 @@ namespace slang
SpecializationConstant = SLANG_PARAMETER_CATEGORY_SPECIALIZATION_CONSTANT,
PushConstantBuffer = SLANG_PARAMETER_CATEGORY_PUSH_CONSTANT_BUFFER,
RegisterSpace = SLANG_PARAMETER_CATEGORY_REGISTER_SPACE,
+ GenericResource = SLANG_PARAMETER_CATEGORY_GENERIC,
};
struct TypeLayoutReflection
@@ -1102,6 +1120,16 @@ namespace slang
{
return (EntryPointReflection*) spReflection_getEntryPointByIndex((SlangReflection*) this, index);
}
+
+ SlangUInt getGlobalConstantBufferBinding()
+ {
+ return spReflection_getGlobalConstantBufferBinding((SlangReflection*)this);
+ }
+
+ size_t getGlobalConstantBufferSize()
+ {
+ return spReflection_getGlobalConstantBufferSize((SlangReflection*)this);
+ }
};
}
diff --git a/source/core/list.h b/source/core/list.h
index af32a39ef..b1461a260 100644
--- a/source/core/list.h
+++ b/source/core/list.h
@@ -487,7 +487,7 @@ namespace Slang
if (predicate(buffer[i]))
return i;
}
- return -1;
+ return (UInt)-1;
}
template<typename Func>
@@ -498,7 +498,7 @@ namespace Slang
if (predicate(buffer[i-1]))
return i-1;
}
- return -1;
+ return (UInt)-1;
}
template<typename T2>
@@ -509,7 +509,7 @@ namespace Slang
if (buffer[i] == val)
return i;
}
- return -1;
+ return (UInt)-1;
}
template<typename T2>
@@ -520,7 +520,7 @@ namespace Slang
if(buffer[i-1] == val)
return i-1;
}
- return -1;
+ return (UInt)-1;
}
void Sort()
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 233a82eef..4b8f4f4c1 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -148,7 +148,6 @@ namespace Slang
return expr->type->As<DeclRefType>();
}
-
RefPtr<Expr> ConstructDeclRefExpr(
DeclRef<Decl> declRef,
RefPtr<Expr> baseExpr,
@@ -1998,6 +1997,22 @@ namespace Slang
decl->SetCheckState(DeclCheckState::Checked);
}
+ void visitGlobalGenericParamDecl(GlobalGenericParamDecl * decl)
+ {
+ if (decl->IsChecked(DeclCheckState::Checked)) return;
+ decl->SetCheckState(DeclCheckState::CheckedHeader);
+ // global generic param only allowed in global scope
+ auto program = decl->ParentDecl->As<ModuleDecl>();
+ if (!program)
+ getSink()->diagnose(decl, Slang::Diagnostics::globalGenParamInGlobalScopeOnly);
+ // Now check all of the member declarations.
+ for (auto member : decl->Members)
+ {
+ checkDecl(member);
+ }
+ decl->SetCheckState(DeclCheckState::Checked);
+ }
+
void visitAssocTypeDecl(AssocTypeDecl* decl)
{
if (decl->IsChecked(DeclCheckState::Checked)) return;
@@ -3703,6 +3718,19 @@ namespace Slang
return true;
}
}
+ // if an inheritance decl is not found, try to find a GenericTypeConstraintDecl
+ for (auto genConstraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(aggTypeDeclRef))
+ {
+ EnsureDecl(genConstraintDeclRef.getDecl());
+ auto inheritedType = GetSup(genConstraintDeclRef);
+ TypeWitnessBreadcrumb breadcrumb;
+ breadcrumb.prev = inBreadcrumbs;
+ breadcrumb.declRef = genConstraintDeclRef;
+ if (doesTypeConformToInterfaceImpl(originalType, inheritedType, interfaceDeclRef, outWitness, &breadcrumb))
+ {
+ return true;
+ }
+ }
}
else if( auto genericTypeParamDeclRef = declRef.As<GenericTypeParamDecl>() )
{
@@ -6582,6 +6610,78 @@ namespace Slang
// that we don't have to re-do this effort again later.
entryPoint->decl = entryPointFuncDecl;
+ // Lookup generic parameter types in global scope
+ for (auto name : entryPoint->genericParameterTypeNames)
+ {
+ if (!translationUnitSyntax->memberDictionary.TryGetValue(name, firstDeclWithName))
+ {
+ // If there doesn't appear to be any such declaration, then
+ // we need to diagnose it as an error, and then bail out.
+ sink->diagnose(translationUnitSyntax, Diagnostics::entryPointTypeParameterNotFound, name);
+ return;
+ }
+ RefPtr<Type> type;
+ if (auto aggType = firstDeclWithName->As<AggTypeDecl>())
+ {
+ type = DeclRefType::Create(entryPoint->compileRequest->mSession, DeclRef<Decl>(aggType, nullptr));
+ }
+ else if (auto typeDefDecl = firstDeclWithName->As<TypeDefDecl>())
+ {
+ type = GetType(DeclRef<TypeDefDecl>(typeDefDecl, nullptr));
+ }
+ else
+ {
+ sink->diagnose(firstDeclWithName, Diagnostics::entryPointTypeSymbolNotAType, name);
+ return;
+ }
+ entryPoint->genericParameterTypes.Add(type);
+ }
+ // check that user-provioded type arguments conforms to the generic type
+ // parameter declaration of this translation unit
+
+ // collect global generic parameters from all imported modules
+ List<RefPtr<GlobalGenericParamDecl>> globalGenericParams;
+ // add current translation unit first
+ {
+ auto globalGenParams = translationUnit->SyntaxNode->getMembersOfType<GlobalGenericParamDecl>();
+ for (auto p : globalGenParams)
+ globalGenericParams.Add(p);
+ }
+ // add imported modules
+ for (auto moduleDecl : entryPoint->compileRequest->loadedModulesList)
+ {
+ auto globalGenParams = moduleDecl->getMembersOfType<GlobalGenericParamDecl>();
+ for (auto p : globalGenParams)
+ globalGenericParams.Add(p);
+ }
+ if (globalGenericParams.Count() != entryPoint->genericParameterTypes.Count())
+ {
+ sink->diagnose(entryPoint->decl, Diagnostics::mismatchEntryPointTypeArgument, globalGenericParams.Count(),
+ entryPoint->genericParameterTypes.Count());
+ return;
+ }
+ // if number of entry-point type arguments matches parameters, try find
+ // SubtypeWitness for each argument
+ int index = 0;
+ for (auto & gParam : globalGenericParams)
+ {
+ for (auto constraint : gParam->getMembersOfType<GenericTypeConstraintDecl>())
+ {
+ auto interfaceType = GetSup(DeclRef<GenericTypeConstraintDecl>(constraint, nullptr));
+ SemanticsVisitor visitor(sink, entryPoint->compileRequest, translationUnit);
+ auto witness = visitor.tryGetSubtypeWitness(entryPoint->genericParameterTypes[index], interfaceType);
+ if (!witness)
+ {
+ sink->diagnose(gParam,
+ Diagnostics::typeArgumentDoesNotConformToInterface, gParam->nameAndLoc.name, entryPoint->genericParameterTypes[index],
+ interfaceType);
+ }
+ entryPoint->genericParameterWitnesses.Add(witness);
+ }
+ index++;
+ }
+ if (sink->errorCount != 0)
+ return;
// TODO: after all that work, we are now in a position to start
// validating the declaration itself. E.g., we should check if
// the declared input/output parameters have suitable semantics,
diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp
index 302b5704f..acbf51e2e 100644
--- a/source/slang/compiler.cpp
+++ b/source/slang/compiler.cpp
@@ -11,7 +11,7 @@
#include "parser.h"
#include "preprocessor.h"
#include "syntax-visitors.h"
-
+#include "type-layout.h"
#include "reflection.h"
#include "emit.h"
@@ -160,7 +160,7 @@ namespace Slang
entryPoint,
targetReq->layout.Ptr(),
CodeGenTarget::HLSL,
- targetReq->target);
+ targetReq);
}
}
@@ -207,7 +207,7 @@ namespace Slang
entryPoint,
targetReq->layout.Ptr(),
CodeGenTarget::GLSL,
- targetReq->target);
+ targetReq);
}
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index f42f36c1f..303be6624 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -100,6 +100,10 @@ namespace Slang
// The name of the entry point function (e.g., `main`)
Name* name;
+
+ // The type names we want to substitute into the
+ // global generic type parameters
+ List<Name*> genericParameterTypeNames;
// The profile that the entry point will be compiled for
// (this is a combination of the target state, and also
@@ -123,6 +127,11 @@ namespace Slang
// it should not be assumed to be available in cases
// where any errors were diagnosed.
RefPtr<FuncDecl> decl;
+
+ // The declaration of the global generic parameter types
+ // This will be filled in as part of semantic analysis.
+ List<RefPtr<Type>> genericParameterTypes;
+ List<RefPtr<Val>> genericParameterWitnesses;
};
enum class PassThroughMode : SlangPassThrough
@@ -319,7 +328,8 @@ namespace Slang
int addEntryPoint(
int translationUnitIndex,
String const& name,
- Profile profile);
+ Profile profile,
+ List<String> const & genericTypeNames);
UInt addTarget(
CodeGenTarget target);
diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h
index 9c010d156..e24a535c5 100644
--- a/source/slang/decl-defs.h
+++ b/source/slang/decl-defs.h
@@ -126,6 +126,11 @@ END_SYNTAX_CLASS()
SYNTAX_CLASS(AssocTypeDecl, AggTypeDecl)
END_SYNTAX_CLASS()
+// A '__generic_param' declaration, which defines a generic
+// entry-point parameter. Is a container of GenericTypeConstraintDecl
+SYNTAX_CLASS(GlobalGenericParamDecl, AggTypeDecl)
+END_SYNTAX_CLASS()
+
// A scope for local declarations (e.g., as part of a statement)
SIMPLE_SYNTAX_CLASS(ScopeDecl, ContainerDecl)
diff --git a/source/slang/diagnostic-defs.h b/source/slang/diagnostic-defs.h
index 7f27e43e8..24e8bc713 100644
--- a/source/slang/diagnostic-defs.h
+++ b/source/slang/diagnostic-defs.h
@@ -196,7 +196,7 @@ DIAGNOSTIC(33070, Error, expectedFunction, "expression preceding parenthesis of
// 303xx: interfaces and associated types
DIAGNOSTIC(30300, Error, assocTypeInInterfaceOnly, "'associatedtype' can only be defined in an 'interface'.")
-
+DIAGNOSTIC(30301, Error, globalGenParamInGlobalScopeOnly, "'__generic_param' can only be defined global scope.")
// TODO: need to assign numbers to all these extra diagnostics...
DIAGNOSTIC(39999, Error, expectedIntegerConstantWrongType, "expected integer constant (found: '$0')")
@@ -244,11 +244,17 @@ DIAGNOSTIC(38001, Error, ambiguousEntryPoint, "more than one function matches en
DIAGNOSTIC(38002, Note, entryPointCandidate, "see candidate declaration for entry point '$0'")
DIAGNOSTIC(38003, Error, entryPointSymbolNotAFunction, "entry point '$0' must be declared as a function")
+DIAGNOSTIC(38004, Error, entryPointTypeParameterNotFound, "no type found matching entry-point type parameter name '$0'")
+DIAGNOSTIC(38005, Error, entryPointTypeSymbolNotAType, "entry-point type parameter '$0' must be declared as a type")
+
DIAGNOSTIC(38100, Error, typeDoesntImplementInterfaceRequirement, "type '$0' does not provide required interface member '$1'")
DIAGNOSTIC(38101, Error, thisExpressionOutsideOfTypeDecl, "'this' expression can only be used in members of an aggregate type")
DIAGNOSTIC(38102, Error, initializerNotInsideType, "an 'init' declaration is only allowed inside a type or 'extension' declaration")
DIAGNOSTIC(38102, Error, accessorMustBeInsideSubscriptOrProperty, "an accessor declaration is only allowed inside a subscript or property declaration")
+DIAGNOSTIC(38020, Error, mismatchEntryPointTypeArgument, "expecting $0 entry-point type arguments, provided $1.")
+DIAGNOSTIC(38021, Error, typeArgumentDoesNotConformToInterface, "type argument `$1` for generic parameter `$0` does not conform to interface `$1`.")
+
//
// 4xxxx - IL code generation.
//
@@ -264,7 +270,6 @@ DIAGNOSTIC(49999, Error, unknownSystemValueSemantic, "unknown system-value seman
//
// 5xxxx - Target code generation.
//
-
DIAGNOSTIC(50020, Error, unknownStageType, "Unknown stage type '$0'.")
DIAGNOSTIC(50020, Error, invalidTessCoordType, "TessCoord must have vec2 or vec3 type.")
DIAGNOSTIC(50020, Error, invalidFragCoordType, "FragCoord must be a vec4.")
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 5b7a42ad7..614e8f474 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -3481,9 +3481,9 @@ struct EmitVisitor
break;
case LayoutResourceKind::RegisterSpace:
+ case LayoutResourceKind::GenericResource:
// ignore
break;
-
default:
{
Emit(": register(");
@@ -6771,7 +6771,7 @@ EntryPointLayout* findEntryPointLayout(
StructTypeLayout* getGlobalStructLayout(
ProgramLayout* programLayout)
{
- auto globalScopeLayout = programLayout->globalScopeLayout;
+ auto globalScopeLayout = programLayout->globalScopeLayout->typeLayout;
if( auto gs = globalScopeLayout.As<StructTypeLayout>() )
{
return gs.Ptr();
@@ -6816,13 +6816,13 @@ String emitEntryPoint(
EntryPointRequest* entryPoint,
ProgramLayout* programLayout,
CodeGenTarget target,
- CodeGenTarget finalTarget)
+ TargetRequest* targetRequest)
{
auto translationUnit = entryPoint->getTranslationUnit();
SharedEmitContext sharedContext;
sharedContext.target = target;
- sharedContext.finalTarget = finalTarget;
+ sharedContext.finalTarget = targetRequest->target;
sharedContext.entryPoint = entryPoint;
if (entryPoint)
@@ -6890,7 +6890,8 @@ String emitEntryPoint(
auto lowered = specializeIRForEntryPoint(
entryPoint,
programLayout,
- target);
+ target,
+ targetRequest);
// If the user specified the flag that they want us to dump
// IR, then do it here, for the target-specific, but
diff --git a/source/slang/emit.h b/source/slang/emit.h
index e17a84d5a..98845f9c6 100644
--- a/source/slang/emit.h
+++ b/source/slang/emit.h
@@ -26,8 +26,7 @@ namespace Slang
// The target language to generate code in (e.g., HLSL/GLSL)
CodeGenTarget target,
- // The "final" target that we are being asked to compile for
- // (e.g., SPIR-V, DXBC, ...).
- CodeGenTarget finalTarget);
+ // The full target request
+ TargetRequest* targetRequest);
}
#endif
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 52acf6576..a91143a43 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -419,6 +419,7 @@ struct IRBuilder
IRFunc* createFunc();
IRGlobalVar* createGlobalVar(
IRType* valueType);
+ IRWitnessTable* createWitnessTable(Dictionary<DeclRef<Decl>, Decl*> & witnesses);
IRWitnessTable* createWitnessTable();
IRWitnessTableEntry* createWitnessTableEntry(
IRWitnessTable* witnessTable,
@@ -565,7 +566,8 @@ struct IRBuilder
IRModule* specializeIRForEntryPoint(
EntryPointRequest* entryPointRequest,
ProgramLayout* programLayout,
- CodeGenTarget target);
+ CodeGenTarget target,
+ TargetRequest* targetReq);
// Find suitable uses of the `specialize` instruction that
// can be replaced with references to specialized functions.
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 9068e717b..bfc26643c 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -3089,12 +3089,16 @@ namespace Slang
// to the layout to use for it.
Dictionary<String, VarLayout*> globalVarLayouts;
+ RefPtr<GlobalGenericParamSubstitution> subst;
+
// Override the "maybe clone" logic so that we always clone
virtual IRValue* maybeCloneValue(IRValue* originalVal) override;
// Override teh "maybe clone" logic so that we carefully
// clone any IR proxy values inside substitutions
virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) override;
+
+ virtual RefPtr<Type> maybeCloneType(Type* originalType) override;
};
@@ -3102,6 +3106,11 @@ namespace Slang
IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc);
IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar);
+ RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType)
+ {
+ return originalType->Substitute(subst).As<Type>();
+ }
+
IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue)
{
switch (originalValue->op)
@@ -3143,6 +3152,33 @@ namespace Slang
case kIROp_decl_ref:
{
IRDeclRef* od = (IRDeclRef*)originalValue;
+
+ // if the declRef is one of the __generic_param decl being substituted by subst
+ // return the substituted decl
+ if (subst)
+ {
+ if (od->declRef.getDecl() == subst->paramDecl)
+ return builder->getTypeVal(subst->actualType.As<Type>());
+ else if (auto genConstraint = od->declRef.As<GenericTypeConstraintDecl>())
+ {
+ // a decl-ref to GenericTypeConstraintDecl as a result of
+ // referencing a generic parameter type should be replaced with
+ // the actual witness table
+ if (genConstraint.getDecl()->ParentDecl == subst->paramDecl)
+ {
+ // find the witness table from subst
+ for (auto witness : subst->witnessTables)
+ {
+ if (witness.Key->EqualsVal(GetSup(genConstraint)))
+ {
+ auto proxyVal = witness.Value.As<IRProxyVal>();
+ SLANG_ASSERT(proxyVal);
+ return proxyVal->inst;
+ }
+ }
+ }
+ }
+ }
auto declRef = maybeCloneDeclRef(od->declRef);
return builder->getDeclRefVal(declRef);
}
@@ -3150,7 +3186,9 @@ namespace Slang
case kIROp_TypeType:
{
IRValue* od = (IRValue*)originalValue;
- return builder->getTypeVal(od->type);
+ int ioDiff = 0;
+ auto newType = od->type->SubstituteImpl(subst, &ioDiff);
+ return builder->getTypeVal(newType.As<Type>());
}
break;
default:
@@ -3207,7 +3245,9 @@ namespace Slang
newSubst->outer = cloneSubstitutions(context, subst->outer);
return newSubst;
}
- return nullptr;
+ else
+ SLANG_UNREACHABLE("unimplemented cloneSubstitution");
+ UNREACHABLE_RETURN(nullptr);
}
DeclRef<Decl> IRSpecContext::maybeCloneDeclRef(DeclRef<Decl> const& declRef)
@@ -3281,7 +3321,7 @@ namespace Slang
IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar)
{
- auto clonedVar = context->builder->createGlobalVar(originalVar->getType()->getValueType());
+ auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType()));
registerClonedValue(context, clonedVar, originalVar);
auto mangledName = originalVar->mangledName;
@@ -3703,10 +3743,67 @@ namespace Slang
}
}
+ // implementation provided in parameter-binding.cpp
+ RefPtr<ProgramLayout> specializeProgramLayout(
+ TargetRequest * targetReq,
+ ProgramLayout* programLayout,
+ Substitutions * typeSubst);
+
+ RefPtr<GlobalGenericParamSubstitution> createGlobalGenericParamSubstitution(
+ EntryPointRequest * entryPointRequest,
+ ProgramLayout * programLayout,
+ IRSpecContext* context,
+ IRModule* originalIRModule)
+ {
+ RefPtr<GlobalGenericParamSubstitution> globalParamSubst;
+ Substitutions * curTailSubst = nullptr;
+ for (auto param : programLayout->globalGenericParams)
+ {
+ auto paramSubst = new GlobalGenericParamSubstitution();
+ if (!globalParamSubst)
+ globalParamSubst = paramSubst;
+ if (curTailSubst)
+ curTailSubst->outer = paramSubst;
+ curTailSubst = paramSubst;
+ paramSubst->paramDecl = param->decl;
+ SLANG_ASSERT((UInt)param->index < entryPointRequest->genericParameterTypes.Count());
+ paramSubst->actualType = entryPointRequest->genericParameterTypes[param->index];
+ // find witness tables
+ for (auto witness : entryPointRequest->genericParameterWitnesses)
+ {
+ if (auto subtypeWitness = witness.As<SubtypeWitness>())
+ {
+ if (subtypeWitness->sub->EqualsVal(paramSubst->actualType))
+ {
+ auto witnessTableName = getMangledNameForConformanceWitness(subtypeWitness->sub, subtypeWitness->sup);
+ auto globalVar = originalIRModule->getFirstGlobalValue();
+ IRGlobalValue * table = nullptr;
+ while (globalVar)
+ {
+ if (globalVar->mangledName == witnessTableName)
+ {
+ table = globalVar;
+ break;
+ }
+ globalVar = globalVar->getNextValue();
+ }
+ SLANG_ASSERT(table);
+ table = cloneWitnessTable(context, (IRWitnessTable*)(table));
+ IRProxyVal * tableVal = new IRProxyVal();
+ tableVal->inst = table;
+ paramSubst->witnessTables.Add(KeyValuePair<RefPtr<Type>, RefPtr<Val>>(subtypeWitness->sup, tableVal));
+ }
+ }
+ }
+ }
+ return globalParamSubst;
+ }
+
IRModule* specializeIRForEntryPoint(
EntryPointRequest* entryPointRequest,
ProgramLayout* programLayout,
- CodeGenTarget target)
+ CodeGenTarget target,
+ TargetRequest* targetReq)
{
auto compileRequest = entryPointRequest->compileRequest;
auto session = compileRequest->mSession;
@@ -3720,8 +3817,6 @@ namespace Slang
return nullptr;
}
- auto entryPointLayout = findEntryPointLayout(programLayout, entryPointRequest);
-
// We now need to start cloning IR symbols from `originalIRModule`
// into a fresh IR module for this entry point. Along the way we need to:
//
@@ -3746,11 +3841,21 @@ namespace Slang
context->builder = &sharedContextStorage.builderStorage;
context->target = target;
+ // Create the GlobalGenericParamSubstitution for substituting global generic types
+ // into user-provided type arguments
+ auto globalParamSubst = createGlobalGenericParamSubstitution(entryPointRequest, programLayout, context, originalIRModule);
+
+ context->subst = globalParamSubst;
+
+ // now specailize the program layout using the substitution
+ RefPtr<ProgramLayout> newProgramLayout = specializeProgramLayout(targetReq, programLayout, globalParamSubst);
+
+ auto entryPointLayout = findEntryPointLayout(newProgramLayout, entryPointRequest);
// Next, we want to optimize lookup for layout infromation
// associated with global declarations, so that we can
// look things up based on the IR values (using mangled names)
- auto globalStructLayout = getGlobalStructLayout(programLayout);
+ auto globalStructLayout = getGlobalStructLayout(newProgramLayout);
for (auto globalVarLayout : globalStructLayout->fields)
{
String mangledName = getMangledName(globalVarLayout->varDecl);
diff --git a/source/slang/lookup.cpp b/source/slang/lookup.cpp
index b01732362..86bef3f4d 100644
--- a/source/slang/lookup.cpp
+++ b/source/slang/lookup.cpp
@@ -410,9 +410,9 @@ void lookUpMemberImpl(
if (auto declRefType = type->As<DeclRefType>())
{
auto declRef = declRefType->declRef;
- if (auto assocTypeDeclRef = declRef.As<AssocTypeDecl>())
+ if (declRef.As<AssocTypeDecl>() || declRef.As<GlobalGenericParamDecl>())
{
- for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(assocTypeDeclRef))
+ for (auto constraintDeclRef : getMembersOfType<GenericTypeConstraintDecl>(declRef.As<ContainerDecl>()))
{
// The super-type in the constraint (e.g., `Foo` in `T : Foo`)
// will tell us a type we should use for lookup.
@@ -488,5 +488,4 @@ LookupResult lookUpMember(
return result;
}
-
-}
+} \ No newline at end of file
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 326d25649..0f3e85805 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -3538,7 +3538,9 @@ static void lowerEntryPointToIR(
// the entry point request.
return;
}
-
+ // we need to lower all global type arguments as well
+ for (auto arg : entryPointRequest->genericParameterTypes)
+ lowerType(context, arg);
auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl);
}
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index b375fa80e..5a6603add 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -2870,6 +2870,13 @@ struct LoweringVisitor
UNREACHABLE_RETURN(LoweredDecl());
}
+ LoweredDecl visitGlobalGenericParamDecl(GlobalGenericParamDecl * /*decl*/)
+ {
+ // not supported
+ SLANG_UNREACHABLE("visitGlobalGenericParamDecl in LowerVisitor");
+ UNREACHABLE_RETURN(LoweredDecl());
+ }
+
LoweredDecl visitTypeDefDecl(TypeDefDecl* decl)
{
if (shared->target == CodeGenTarget::GLSL)
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index fa015186b..836ed254f 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -667,6 +667,17 @@ static void collectGlobalScopeGLSLVaryingParameter(
}
// Collect a single declaration into our set of parameters
+static void collectGlobalGenericParameter(
+ ParameterBindingContext* context,
+ RefPtr<GlobalGenericParamDecl> paramDecl)
+{
+ RefPtr<GenericParamLayout> layout = new GenericParamLayout();
+ layout->decl = paramDecl;
+ layout->index = (int)context->shared->programLayout->globalGenericParams.Count();
+ context->shared->programLayout->globalGenericParams.Add(layout);
+}
+
+// Collect a single declaration into our set of parameters
static void collectGlobalScopeParameter(
ParameterBindingContext* context,
RefPtr<VarDeclBase> varDecl)
@@ -1037,7 +1048,13 @@ static void completeBindingsForParameter(
continue;
}
-
+ else if (kind == LayoutResourceKind::GenericResource)
+ {
+ bindingInfo.space = 0;
+ bindingInfo.count = 0;
+ bindingInfo.index = 0;
+ continue;
+ }
// For now we only auto-generate bindings in space zero
//
@@ -1065,6 +1082,11 @@ static void completeBindingsForParameter(
bindingInfo.space = space;
}
+ if (firstTypeLayout->FindResourceInfo(LayoutResourceKind::GenericResource))
+ {
+
+ }
+
// At this point we should have explicit binding locations chosen for
// all the relevant resource kinds, so we can apply these to the
// declarations:
@@ -1093,15 +1115,22 @@ static void collectGlobalScopeParameters(
ModuleDecl* program)
{
// First enumerate parameters at global scope
- for( auto decl : program->Members )
+ // We collect two things here:
+ // 1. A shader parameter, which is always a variable
+ // 2. A global entry-point generic parameter type (`__generic_param`),
+ // which is a GlobalGenericParamDecl
+ // We collect global generic type parameters in the first pass,
+ // So we can fill in the correct index into ordinary type layouts
+ // for generic types in the second pass.
+ for (auto decl : program->Members)
{
- // A shader parameter is always a variable,
- // so skip declarations that aren't variables.
- auto varDecl = decl.As<VarDeclBase>();
- if (!varDecl)
- continue;
-
- collectGlobalScopeParameter(context, varDecl);
+ if (auto genParamDecl = decl.As<GlobalGenericParamDecl>())
+ collectGlobalGenericParameter(context, genParamDecl);
+ }
+ for (auto decl : program->Members)
+ {
+ if (auto varDecl = decl.As<VarDeclBase>())
+ collectGlobalScopeParameter(context, varDecl);
}
// Next, we need to enumerate the parameters of
@@ -1665,7 +1694,8 @@ void generateParameterBindings(
if (!layoutContext.rules)
return;
- RefPtr<ProgramLayout> programLayout = new ProgramLayout;
+ RefPtr<ProgramLayout> programLayout = new ProgramLayout();
+ targetReq->layout = programLayout;
// Create a context to hold shared state during the process
// of generating parameter bindings
@@ -1680,7 +1710,6 @@ void generateParameterBindings(
context.shared = &sharedContext;
context.translationUnit = nullptr;
context.layoutContext = layoutContext;
-
// Walk through AST to discover all the parameters
collectParameters(&context, compileReq);
@@ -1707,6 +1736,7 @@ void generateParameterBindings(
// If there are any global-scope uniforms, then we need to
// allocate a constant-buffer binding for them here.
ParameterBindingInfo globalConstantBufferBinding;
+ globalConstantBufferBinding.index = 0;
if( anyGlobalUniforms )
{
// TODO: this logic is only correct for D3D targets, where
@@ -1838,8 +1868,191 @@ void generateParameterBindings(
// We now have a bunch of layout information, which we should
// record into a suitable object that represents the program
- programLayout->globalScopeLayout = globalScopeLayout;
- targetReq->layout = programLayout;
+ RefPtr<VarLayout> globalVarLayout = new VarLayout();
+ globalVarLayout->typeLayout = globalScopeLayout;
+ if (anyGlobalUniforms)
+ {
+ auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer);
+ cbInfo->space = 0;
+ cbInfo->index = globalConstantBufferBinding.index;
+ }
+ programLayout->globalScopeLayout = globalVarLayout;
}
+StructTypeLayout* getGlobalStructLayout(
+ ProgramLayout* programLayout);
+
+RefPtr<ProgramLayout> specializeProgramLayout(
+ TargetRequest * targetReq,
+ ProgramLayout* programLayout,
+ Substitutions * typeSubst)
+{
+ RefPtr<ProgramLayout> newProgramLayout;
+ newProgramLayout = new ProgramLayout();
+ newProgramLayout->bindingForHackSampler = programLayout->bindingForHackSampler;
+ newProgramLayout->hackSamplerVar = programLayout->hackSamplerVar;
+ for (auto & entryPoint : programLayout->entryPoints)
+ {
+ RefPtr<EntryPointLayout> newEntryPoint = new EntryPointLayout(*entryPoint);
+ // TODO: for now just copy existing entry point layouts, but we eventually need to
+ // specialize these as well...
+ newProgramLayout->entryPoints.Add(newEntryPoint);
+ }
+
+ List<RefPtr<TypeLayout>> paramTypeLayouts;
+ auto globalStructLayout = getGlobalStructLayout(programLayout);
+ SLANG_ASSERT(globalStructLayout);
+ RefPtr<StructTypeLayout> structLayout = new StructTypeLayout();
+ RefPtr<TypeLayout> globalScopeLayout = structLayout;
+ structLayout->uniformAlignment = globalStructLayout->uniformAlignment;
+
+ // Try to find rules based on the selected code-generation target
+ auto layoutContext = getInitialLayoutContextForTarget(targetReq);
+
+ // If there was no target, or there are no rules for the target,
+ // then bail out here.
+ if (!layoutContext.rules)
+ return newProgramLayout;
+
+
+ // we need to initialize a layout context to mark used registers
+ SharedParameterBindingContext sharedContext;
+ sharedContext.compileRequest = targetReq->compileRequest;
+ sharedContext.defaultLayoutRules = layoutContext.getRulesFamily();
+ sharedContext.programLayout = programLayout;
+
+ // Create a sub-context to collect parameters that get
+ // declared into the global scope
+ ParameterBindingContext context;
+ context.shared = &sharedContext;
+ context.translationUnit = nullptr;
+ context.layoutContext = layoutContext;
+
+ auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules();
+ structLayout->rules = constantBufferRules;
+
+ UniformLayoutInfo structLayoutInfo;
+ structLayoutInfo.alignment = globalStructLayout->uniformAlignment;
+ structLayoutInfo.size = 0;
+ bool anyUniforms = false;
+ Dictionary<RefPtr<VarLayout>, RefPtr<VarLayout>> varLayoutMapping;
+ for (auto & varLayout : globalStructLayout->fields)
+ {
+ // To recover layout context, we skip generic resources in the first pass
+ // If the var is a generic resource, its resourceInfos will be empty.
+ if (varLayout->resourceInfos.Count() == 0)
+ continue;
+ SLANG_ASSERT(varLayout->resourceInfos.Count() == varLayout->typeLayout->resourceInfos.Count());
+ auto uniformInfo = varLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ auto tUniformInfo = varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ if (uniformInfo)
+ {
+ anyUniforms = true;
+ SLANG_ASSERT(tUniformInfo);
+ structLayoutInfo.size = Math::Max(structLayoutInfo.size, uniformInfo->index + tUniformInfo->count);
+ }
+ for (UInt i = 0; i < varLayout->resourceInfos.Count(); i++)
+ {
+ auto resInfo = varLayout->resourceInfos[i];
+ auto tresInfo = varLayout->typeLayout->resourceInfos[i];
+ SLANG_ASSERT(resInfo.kind == tresInfo.kind);
+ auto usedRangeSet = findUsedRangeSetForSpace(&context, resInfo.space);
+ markSpaceUsed(&context, resInfo.space);
+ usedRangeSet->usedResourceRanges[(int)resInfo.kind].Add(
+ nullptr, // we don't need to track parameter info here
+ resInfo.index,
+ resInfo.index + varLayout->typeLayout->resourceInfos[0].count);
+ }
+ structLayout->fields.Add(varLayout);
+ varLayoutMapping[varLayout] = varLayout;
+ }
+ auto originalGlobalCBufferInfo = programLayout->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer);
+ VarLayout::ResourceInfo globalCBufferInfo;
+ globalCBufferInfo.kind = LayoutResourceKind::None;
+ globalCBufferInfo.space = 0;
+ globalCBufferInfo.index = 0;
+ if (originalGlobalCBufferInfo)
+ {
+ globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer;
+ globalCBufferInfo.space = originalGlobalCBufferInfo->space;
+ globalCBufferInfo.index = originalGlobalCBufferInfo->index;
+ }
+ // we have the context restored, can continue to layout the generic variables now
+ for (auto & varLayout : globalStructLayout->fields)
+ {
+ if (varLayout->typeLayout->FindResourceInfo(LayoutResourceKind::GenericResource))
+ {
+ RefPtr<Type> newType = varLayout->typeLayout->type->Substitute(typeSubst).As<Type>();
+ RefPtr<TypeLayout> newTypeLayout = CreateTypeLayout(
+ layoutContext.with(constantBufferRules),
+ newType);
+ auto layoutInfo = newTypeLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ size_t uniformSize = layoutInfo ? layoutInfo->count : 0;
+ if (uniformSize)
+ {
+ if (globalCBufferInfo.kind == LayoutResourceKind::None)
+ {
+ // user defined a uniform via a global generic type argument
+ // but we have not reserved a binding for the global uniform buffer
+ UInt space = 0;
+ auto usedRangeSet = findUsedRangeSetForSpace(&context, space);
+ globalCBufferInfo.kind = LayoutResourceKind::ConstantBuffer;
+ globalCBufferInfo.index =
+ usedRangeSet->usedResourceRanges[
+ (int)LayoutResourceKind::ConstantBuffer].Allocate(nullptr, 1);
+ globalCBufferInfo.space = space;
+ }
+ }
+ RefPtr<VarLayout> newVarLayout = new VarLayout();
+ RefPtr<ParameterInfo> paramInfo = new ParameterInfo();
+ newVarLayout->varDecl = varLayout->varDecl;
+ newVarLayout->typeLayout = newTypeLayout;
+ paramInfo->varLayouts.Add(newVarLayout);
+ completeBindingsForParameter(&context, paramInfo);
+ // update uniform layout
+
+ if (uniformSize != 0)
+ {
+ // Make sure uniform fields get laid out properly...
+ UniformLayoutInfo fieldInfo(
+ uniformSize,
+ newTypeLayout->uniformAlignment);
+ size_t uniformOffset = layoutContext.getRulesFamily()->getConstantBufferRules()->AddStructField(
+ &structLayoutInfo,
+ fieldInfo);
+ newVarLayout->findOrAddResourceInfo(LayoutResourceKind::Uniform)->index = uniformOffset;
+ anyUniforms = true;
+ }
+ structLayout->fields.Add(newVarLayout);
+ varLayoutMapping[varLayout] = newVarLayout;
+ }
+ }
+ for (auto mapping : globalStructLayout->mapVarToLayout)
+ {
+ RefPtr<VarLayout> updatedVarLayout = mapping.Value;
+ varLayoutMapping.TryGetValue(updatedVarLayout, updatedVarLayout);
+ structLayout->mapVarToLayout[mapping.Key] = updatedVarLayout;
+ }
+
+ // If there are global-scope uniforms, then we need to wrap
+ // up a global constant buffer type layout to hold them
+ RefPtr<VarLayout> globalVarLayout = new VarLayout();
+ if (anyUniforms)
+ {
+ auto globalConstantBufferLayout = createParameterGroupTypeLayout(
+ layoutContext,
+ nullptr,
+ constantBufferRules,
+ constantBufferRules->GetObjectLayout(ShaderParameterKind::ConstantBuffer),
+ structLayout);
+
+ globalScopeLayout = globalConstantBufferLayout;
+ auto cbInfo = globalVarLayout->findOrAddResourceInfo(LayoutResourceKind::ConstantBuffer);
+ *cbInfo = globalCBufferInfo;
+ }
+ globalVarLayout->typeLayout = globalScopeLayout;
+ programLayout->globalScopeLayout = globalVarLayout;
+ newProgramLayout->globalScopeLayout = globalVarLayout;
+ return newProgramLayout;
+}
}
diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp
index 42c763099..0a4360e3f 100644
--- a/source/slang/parser.cpp
+++ b/source/slang/parser.cpp
@@ -95,9 +95,9 @@ namespace Slang
RefPtr<StructDecl> ParseStruct();
RefPtr<ClassDecl> ParseClass();
RefPtr<Stmt> ParseStatement();
- RefPtr<Stmt> ParseBlockStatement();
- RefPtr<DeclStmt> ParseVarDeclrStatement(Modifiers modifiers);
- RefPtr<IfStmt> ParseIfStatement();
+ RefPtr<Stmt> parseBlockStatement();
+ RefPtr<DeclStmt> parseVarDeclrStatement(Modifiers modifiers);
+ RefPtr<IfStmt> parseIfStatement();
RefPtr<ForStmt> ParseForStatement();
RefPtr<WhileStmt> ParseWhileStatement();
RefPtr<DoWhileStmt> ParseDoWhileStatement();
@@ -1034,7 +1034,7 @@ namespace Slang
}
else
{
- decl->Body = parser->ParseBlockStatement();
+ decl->Body = parser->parseBlockStatement();
}
parser->PopScope();
@@ -2172,41 +2172,55 @@ namespace Slang
}
}
- RefPtr<RefObject> ParseAssocType(Parser * parser, void *)
+ void parseOptionalGenericConstraints(Parser * parser, ContainerDecl* decl)
{
- RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl();
-
- auto nameToken = parser->ReadToken(TokenType::Identifier);
- assocTypeDecl->nameAndLoc = NameLoc(nameToken);
- assocTypeDecl->loc = nameToken.loc;
if (AdvanceIf(parser, TokenType::Colon))
{
- while (!parser->tokenReader.IsAtEnd())
+ do
{
- auto paramConstraint = new GenericTypeConstraintDecl();
+ RefPtr<GenericTypeConstraintDecl> paramConstraint = new GenericTypeConstraintDecl();
parser->FillPosition(paramConstraint);
- auto paramType = DeclRefType::Create(
+ RefPtr<DeclRefType> paramType = DeclRefType::Create(
parser->getSession(),
- DeclRef<Decl>(assocTypeDecl, nullptr));
+ DeclRef<Decl>(decl, nullptr));
- auto paramTypeExpr = new SharedTypeExpr();
- paramTypeExpr->loc = assocTypeDecl->loc;
+ RefPtr<SharedTypeExpr> paramTypeExpr = new SharedTypeExpr();
+ paramTypeExpr->loc = decl->loc;
paramTypeExpr->base.type = paramType;
paramTypeExpr->type = QualType(getTypeType(paramType));
paramConstraint->sub = TypeExp(paramTypeExpr);
paramConstraint->sup = parser->ParseTypeExp();
- AddMember(assocTypeDecl, paramConstraint);
- if (!AdvanceIf(parser, TokenType::Comma))
- break;
- }
+ AddMember(decl, paramConstraint);
+ } while (AdvanceIf(parser, TokenType::Comma));
}
+ }
+
+ RefPtr<RefObject> parseAssocType(Parser * parser, void *)
+ {
+ RefPtr<AssocTypeDecl> assocTypeDecl = new AssocTypeDecl();
+
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+ assocTypeDecl->nameAndLoc = NameLoc(nameToken);
+ assocTypeDecl->loc = nameToken.loc;
+ parseOptionalGenericConstraints(parser, assocTypeDecl);
parser->ReadToken(TokenType::Semicolon);
return assocTypeDecl;
}
+ RefPtr<RefObject> parseGlobalGenericParamDecl(Parser * parser, void *)
+ {
+ RefPtr<GlobalGenericParamDecl> genParamDecl = new GlobalGenericParamDecl();
+ auto nameToken = parser->ReadToken(TokenType::Identifier);
+ genParamDecl->nameAndLoc = NameLoc(nameToken);
+ genParamDecl->loc = nameToken.loc;
+ parseOptionalGenericConstraints(parser, genParamDecl);
+ parser->ReadToken(TokenType::Semicolon);
+ return genParamDecl;
+ }
+
static RefPtr<RefObject> parseInterfaceDecl(Parser* parser, void* /*userData*/)
{
RefPtr<InterfaceDecl> decl = new InterfaceDecl();
@@ -2220,7 +2234,7 @@ namespace Slang
return decl;
}
- static RefPtr<RefObject> ParseConstructorDecl(Parser* parser, void* /*userData*/)
+ static RefPtr<RefObject> parseConstructorDecl(Parser* parser, void* /*userData*/)
{
RefPtr<ConstructorDecl> decl = new ConstructorDecl();
parser->FillPosition(decl.Ptr());
@@ -2243,7 +2257,7 @@ namespace Slang
}
else
{
- decl->Body = parser->ParseBlockStatement();
+ decl->Body = parser->parseBlockStatement();
}
return decl;
}
@@ -2271,7 +2285,7 @@ namespace Slang
if( parser->tokenReader.PeekTokenType() == TokenType::LBrace )
{
- decl->Body = parser->ParseBlockStatement();
+ decl->Body = parser->parseBlockStatement();
}
else
{
@@ -2664,7 +2678,7 @@ namespace Slang
parser->ReadToken(TokenType::LParent);
stmt->condition = parser->ParseExpression();
parser->ReadToken(TokenType::RParent);
- stmt->body = parser->ParseBlockStatement();
+ stmt->body = parser->parseBlockStatement();
return stmt;
}
@@ -2788,11 +2802,11 @@ namespace Slang
RefPtr<Stmt> statement;
if (LookAheadToken(TokenType::LBrace))
- statement = ParseBlockStatement();
+ statement = parseBlockStatement();
else if (peekTypeName(this))
- statement = ParseVarDeclrStatement(modifiers);
+ statement = parseVarDeclrStatement(modifiers);
else if (LookAheadToken("if"))
- statement = ParseIfStatement();
+ statement = parseIfStatement();
else if (LookAheadToken("for"))
statement = ParseForStatement();
else if (LookAheadToken("while"))
@@ -2852,7 +2866,7 @@ namespace Slang
// Note: the declaration will consume any modifiers
// that had been in place on the statement.
tokenReader.mCursor = startPos;
- statement = ParseVarDeclrStatement(modifiers);
+ statement = parseVarDeclrStatement(modifiers);
return statement;
}
@@ -2885,7 +2899,7 @@ namespace Slang
return statement;
}
- RefPtr<Stmt> Parser::ParseBlockStatement()
+ RefPtr<Stmt> Parser::parseBlockStatement()
{
// If we are being asked not to check things *and* we haven't
// seen any `import` declarations yet, then we can safely assume
@@ -2983,7 +2997,7 @@ namespace Slang
return blockStatement;
}
- RefPtr<DeclStmt> Parser::ParseVarDeclrStatement(
+ RefPtr<DeclStmt> Parser::parseVarDeclrStatement(
Modifiers modifiers)
{
RefPtr<DeclStmt>varDeclrStatement = new DeclStmt();
@@ -2994,7 +3008,7 @@ namespace Slang
return varDeclrStatement;
}
- RefPtr<IfStmt> Parser::ParseIfStatement()
+ RefPtr<IfStmt> Parser::parseIfStatement()
{
RefPtr<IfStmt> ifStatement = new IfStmt();
FillPosition(ifStatement.Ptr());
@@ -3045,7 +3059,7 @@ namespace Slang
ReadToken(TokenType::LParent);
if (peekTypeName(this))
{
- stmt->InitialStatement = ParseVarDeclrStatement(Modifiers());
+ stmt->InitialStatement = parseVarDeclrStatement(Modifiers());
}
else
{
@@ -3107,7 +3121,7 @@ namespace Slang
return breakStatement;
}
- RefPtr<ContinueStmt> Parser::ParseContinueStatement()
+ RefPtr<ContinueStmt> Parser::ParseContinueStatement()
{
RefPtr<ContinueStmt> continueStatement = new ContinueStmt();
FillPosition(continueStatement.Ptr());
@@ -4128,17 +4142,18 @@ namespace Slang
// Add syntax for declaration keywords
#define DECL(KEYWORD, CALLBACK) \
addBuiltinSyntax<Decl>(session, scope, #KEYWORD, &CALLBACK)
- DECL(typedef, ParseTypeDef);
- DECL(associatedtype,ParseAssocType);
- DECL(cbuffer, parseHLSLCBufferDecl);
- DECL(tbuffer, parseHLSLTBufferDecl);
- DECL(__generic, ParseGenericDecl);
- DECL(__extension, ParseExtensionDecl);
- DECL(__init, ParseConstructorDecl);
- DECL(__subscript, ParseSubscriptDecl);
- DECL(interface, parseInterfaceDecl);
- DECL(syntax, parseSyntaxDecl);
- DECL(__import, parseImportDecl);
+ DECL(typedef, ParseTypeDef);
+ DECL(associatedtype, parseAssocType);
+ DECL(__generic_param, parseGlobalGenericParamDecl);
+ DECL(cbuffer, parseHLSLCBufferDecl);
+ DECL(tbuffer, parseHLSLTBufferDecl);
+ DECL(__generic, ParseGenericDecl);
+ DECL(__extension, ParseExtensionDecl);
+ DECL(__init, parseConstructorDecl);
+ DECL(__subscript, ParseSubscriptDecl);
+ DECL(interface, parseInterfaceDecl);
+ DECL(syntax, parseSyntaxDecl);
+ DECL(__import, parseImportDecl);
#undef DECL
diff --git a/source/slang/reflection.cpp b/source/slang/reflection.cpp
index 9fc032c76..14199f126 100644
--- a/source/slang/reflection.cpp
+++ b/source/slang/reflection.cpp
@@ -886,3 +886,22 @@ SLANG_API SlangReflectionEntryPoint* spReflection_getEntryPointByIndex(SlangRefl
return convert(program->entryPoints[(int) index].Ptr());
}
+
+SLANG_API SlangUInt spReflection_getGlobalConstantBufferBinding(SlangReflection* inProgram)
+{
+ auto program = convert(inProgram);
+ if (!program) return 0;
+ auto cb = program->globalScopeLayout->FindResourceInfo(LayoutResourceKind::ConstantBuffer);
+ if (!cb) return 0;
+ return cb->index;
+}
+
+SLANG_API size_t spReflection_getGlobalConstantBufferSize(SlangReflection* inProgram)
+{
+ auto program = convert(inProgram);
+ if (!program) return 0;
+ auto structLayout = getGlobalStructLayout(program);
+ auto uniform = structLayout->FindResourceInfo(LayoutResourceKind::Uniform);
+ if (!uniform) return 0;
+ return uniform->count;
+}
diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp
index 17f8ea96d..6a103fc2d 100644
--- a/source/slang/slang.cpp
+++ b/source/slang/slang.cpp
@@ -400,14 +400,16 @@ void CompileRequest::addTranslationUnitSourceFile(
int CompileRequest::addEntryPoint(
int translationUnitIndex,
String const& name,
- Profile entryPointProfile)
+ Profile entryPointProfile,
+ List<String> const & genericTypeNames)
{
RefPtr<EntryPointRequest> entryPoint = new EntryPointRequest();
entryPoint->compileRequest = this;
entryPoint->name = getNamePool()->getName(name);
entryPoint->profile = entryPointProfile;
entryPoint->translationUnitIndex = translationUnitIndex;
-
+ for (auto typeName : genericTypeNames)
+ entryPoint->genericParameterTypeNames.Add(getNamePool()->getName(typeName));
auto translationUnit = translationUnits[translationUnitIndex].Ptr();
translationUnit->entryPoints.Add(entryPoint);
@@ -909,7 +911,31 @@ SLANG_API int spAddEntryPoint(
return req->addEntryPoint(
translationUnitIndex,
name,
- Slang::Profile(Slang::Profile::RawVal(profile)));
+ Slang::Profile(Slang::Profile::RawVal(profile)),
+ Slang::List<Slang::String>());
+}
+
+SLANG_API int spAddEntryPointEx(
+ SlangCompileRequest* request,
+ int translationUnitIndex,
+ char const* name,
+ SlangProfileID profile,
+ int genericParamTypeNameCount,
+ char const ** genericParamTypeNames)
+{
+ if (!request) return -1;
+ auto req = REQ(request);
+ if (!name) return -1;
+ if (translationUnitIndex < 0) return -1;
+ if (Slang::UInt(translationUnitIndex) >= req->translationUnits.Count()) return -1;
+ Slang::List<Slang::String> typeNames;
+ for (int i = 0; i < genericParamTypeNameCount; i++)
+ typeNames.Add(genericParamTypeNames[i]);
+ return req->addEntryPoint(
+ translationUnitIndex,
+ name,
+ Slang::Profile(Slang::Profile::RawVal(profile)),
+ typeNames);
}
diff --git a/source/slang/syntax-base-defs.h b/source/slang/syntax-base-defs.h
index 3c7e8c5ae..fdb2694a9 100644
--- a/source/slang/syntax-base-defs.h
+++ b/source/slang/syntax-base-defs.h
@@ -197,6 +197,38 @@ SYNTAX_CLASS(ThisTypeSubstitution, Substitutions)
)
END_SYNTAX_CLASS()
+SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions)
+ // the __generic_param decl to be substituted
+ DECL_FIELD(GlobalGenericParamDecl*, paramDecl)
+ // the actual type to substitute in
+ SYNTAX_FIELD(RefPtr<Val>, actualType)
+
+RAW(
+ // Apply a set of substitutions to the bindings in this substitution
+ virtual RefPtr<Substitutions> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
+
+ // Check if these are equivalent substitutiosn to another set
+ virtual bool Equals(Substitutions* subst) override;
+ virtual bool operator == (const Substitutions & subst) override
+ {
+ return Equals(const_cast<Substitutions*>(&subst));
+ }
+ virtual int GetHashCode() const override
+ {
+ int rs = actualType->GetHashCode();
+ for (auto && v : witnessTables)
+ {
+ rs = combineHash(rs, v.Key->GetHashCode());
+ rs = combineHash(rs, v.Value->GetHashCode());
+ }
+ return rs;
+ }
+ typedef List<KeyValuePair<RefPtr<Type>, RefPtr<Val>>> WitnessTableLookupTable;
+)
+ // The witness tables for each interface this actual type implements
+ SYNTAX_FIELD(WitnessTableLookupTable, witnessTables)
+END_SYNTAX_CLASS()
+
ABSTRACT_SYNTAX_CLASS(SyntaxNode, SyntaxNodeBase)
END_SYNTAX_CLASS()
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index e5fc8dfa3..fa9c88051 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -93,6 +93,7 @@ ABSTRACT_SYNTAX_CLASS(Expr, SyntaxNode);
ABSTRACT_SYNTAX_CLASS(Substitutions, SyntaxNode);
ABSTRACT_SYNTAX_CLASS(GenericSubstitution, Substitutions);
ABSTRACT_SYNTAX_CLASS(ThisTypeSubstitution, Substitutions);
+ABSTRACT_SYNTAX_CLASS(GlobalGenericParamSubstitution, Substitutions);
#include "expr-defs.h"
#include "decl-defs.h"
@@ -488,6 +489,20 @@ void Type::accept(IValVisitor* visitor, void* extra)
}
}
}
+ else if (auto globalGenParam = dynamic_cast<GlobalGenericParamDecl*>(declRef.getDecl()))
+ {
+ // search for a substitution that might apply to us
+ for (auto s = subst; s; s = s->outer.Ptr())
+ {
+ if (auto genericSubst = dynamic_cast<GlobalGenericParamSubstitution*>(s))
+ {
+ if (genericSubst->paramDecl == globalGenParam)
+ {
+ return genericSubst->actualType;
+ }
+ }
+ }
+ }
int diff = 0;
DeclRef<Decl> substDeclRef = declRef.SubstituteImpl(subst, &diff);
@@ -1208,6 +1223,35 @@ void Type::accept(IValVisitor* visitor, void* extra)
return false;
}
+ RefPtr<Substitutions> GlobalGenericParamSubstitution::SubstituteImpl(Substitutions* /*subst*/, int* /*ioDiff*/)
+ {
+ // we will never replace values for this type of substitution
+ return this;
+ }
+
+ bool GlobalGenericParamSubstitution::Equals(Substitutions* subst)
+ {
+ if (!subst)
+ return false;
+ if (auto genSubst = dynamic_cast<GlobalGenericParamSubstitution*>(subst))
+ {
+ if (paramDecl != genSubst->paramDecl)
+ return false;
+ if (!actualType->EqualsVal(genSubst->actualType))
+ return false;
+ if (witnessTables.Count() != genSubst->witnessTables.Count())
+ return false;
+ for (UInt i = 0; i < witnessTables.Count(); i++)
+ {
+ if (!witnessTables[i].Key->Equals(genSubst->witnessTables[i].Key))
+ return false;
+ if (!witnessTables[i].Value->EqualsVal(genSubst->witnessTables[i].Value))
+ return false;
+ }
+ return true;
+ }
+ return false;
+ }
// DeclRefBase
@@ -1564,6 +1608,24 @@ void Type::accept(IValVisitor* visitor, void* extra)
return genericSubst->args[index + ordinaryParamCount];
}
}
+ else if (auto globalGenParamSubst = dynamic_cast<GlobalGenericParamSubstitution*>(s))
+ {
+ // we have a GlobalGenericParamSubstitution, this substitution will provide
+ // a concrete IRWitnessTable for a generic global variable
+ auto supType = GetSup(genConstraintDecl);
+
+ // check if the substitution is really about this global generic type parameter
+ if (globalGenParamSubst->paramDecl != genConstraintDecl.getDecl()->ParentDecl)
+ continue;
+
+ // find witness table for the required interface
+ for (auto witness : globalGenParamSubst->witnessTables)
+ if (witness.Key->EqualsVal(supType))
+ {
+ (*ioDiff)++;
+ return witness.Value;
+ }
+ }
}
}
RefPtr<DeclaredSubtypeWitness> rs = new DeclaredSubtypeWitness();
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 46beca2d9..b4d550ef5 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -1073,7 +1073,7 @@ namespace Slang
{
return declRef.Substitute(declRef.getDecl()->base.type);
}
-
+
inline RefPtr<Type> GetType(DeclRef<TypeDefDecl> const& declRef)
{
return declRef.Substitute(declRef.getDecl()->type.Ptr());
diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp
index 8fa790dd8..30b2ee01a 100644
--- a/source/slang/type-layout.cpp
+++ b/source/slang/type-layout.cpp
@@ -1222,6 +1222,11 @@ SimpleLayoutInfo GetLayoutImpl(
return GetLayoutImpl(subContext, type, outTypeLayout, SimpleLayoutInfo());
}
+int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl)
+{
+ return (int)genericParameters.FindFirst([=](RefPtr<GenericParamLayout> & x) {return x->decl.Ptr() == decl; });
+}
+
SimpleLayoutInfo GetLayoutImpl(
TypeLayoutContext const& context,
Type* type,
@@ -1599,6 +1604,25 @@ SimpleLayoutInfo GetLayoutImpl(
return info;
}
+ else if (auto globalGenParam = declRef.As<GlobalGenericParamDecl>())
+ {
+ SimpleLayoutInfo info;
+ info.alignment = 0;
+ info.size = 0;
+ info.kind = LayoutResourceKind::GenericResource;
+ if (outTypeLayout)
+ {
+ auto genParamTypeLayout = new GenericParamTypeLayout();
+ // we should have already populated ProgramLayout::genericEntryPointParams list at this point,
+ // so we can find the index of this generic param decl in the list
+ genParamTypeLayout->type = type;
+ genParamTypeLayout->paramIndex = findGenericParam(context.targetReq->layout->globalGenericParams, genParamTypeLayout->getGlobalGenericParamDecl());
+ genParamTypeLayout->rules = rules;
+ genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count++;
+ *outTypeLayout = genParamTypeLayout;
+ }
+ return info;
+ }
}
else if (auto errorType = type->As<ErrorType>())
{
@@ -1667,4 +1691,12 @@ RefPtr<TypeLayout> CreateTypeLayout(
return CreateTypeLayout(context, type, SimpleLayoutInfo());
}
+RefPtr<GlobalGenericParamDecl> GenericParamTypeLayout::getGlobalGenericParamDecl()
+{
+ auto declRefType = type->AsDeclRefType();
+ SLANG_ASSERT(declRefType);
+ auto rsDeclRef = declRefType->declRef.As<GlobalGenericParamDecl>();
+ return rsDeclRef.getDecl();
+}
+
} // namespace Slang
diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h
index 363b01486..4ce6dc355 100644
--- a/source/slang/type-layout.h
+++ b/source/slang/type-layout.h
@@ -220,7 +220,7 @@ typedef unsigned int VarLayoutFlags;
enum VarLayoutFlag : VarLayoutFlags
{
IsRedeclaration = 1 << 0, ///< This is a redeclaration of some shader parameter
- HasSemantic = 1 << 1,
+ HasSemantic = 1 << 1
};
// A reified layout for a particular variable, field, etc.
@@ -358,6 +358,13 @@ public:
Dictionary<Decl*, RefPtr<VarLayout>> mapVarToLayout;
};
+class GenericParamTypeLayout : public TypeLayout
+{
+public:
+ RefPtr<GlobalGenericParamDecl> getGlobalGenericParamDecl();
+ int paramIndex = 0;
+};
+
// Layout information for a single shader entry point
// within a program
//
@@ -386,6 +393,13 @@ public:
unsigned flags = 0;
};
+class GenericParamLayout : public Layout
+{
+public:
+ RefPtr<GlobalGenericParamDecl> decl;
+ int index;
+};
+
// Layout information for the global scope of a program
class ProgramLayout : public Layout
{
@@ -403,13 +417,15 @@ public:
// (since a constant buffer will have to be allocated
// to store them).
//
- RefPtr<TypeLayout> globalScopeLayout;
+ RefPtr<VarLayout> globalScopeLayout;
// We catalog the requested entry points here,
// and any entry-point-specific parameter data
// will (eventually) belong there...
List<RefPtr<EntryPointLayout>> entryPoints;
+ List<RefPtr<GenericParamLayout>> globalGenericParams;
+
// HACK: binding to use when we have to create
// a dummy sampler just to appease glslang
int bindingForHackSampler = 0;
diff --git a/tests/compute/global-type-param.slang b/tests/compute/global-type-param.slang
new file mode 100644
index 000000000..301ef1021
--- /dev/null
+++ b/tests/compute/global-type-param.slang
@@ -0,0 +1,30 @@
+//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT:ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT:type Impl
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IBase
+{
+ float compute();
+}
+
+struct Impl : IBase
+{
+ float compute()
+ {
+ return 1.0;
+ }
+};
+
+__generic_param TImpl : IBase;
+
+TImpl impl;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ float outVal = impl.compute();
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/global-type-param.slang.expected.txt b/tests/compute/global-type-param.slang.expected.txt
new file mode 100644
index 000000000..47b9ba0c8
--- /dev/null
+++ b/tests/compute/global-type-param.slang.expected.txt
@@ -0,0 +1 @@
+3F800000 \ No newline at end of file
diff --git a/tests/compute/global-type-param1.slang b/tests/compute/global-type-param1.slang
new file mode 100644
index 000000000..c9b754aa3
--- /dev/null
+++ b/tests/compute/global-type-param1.slang
@@ -0,0 +1,46 @@
+//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT: cbuffer(data=[0.5 0 0 0 1.0], stride=4):dxbinding(0),glbinding(0)
+//TEST_INPUT: cbuffer(data=[1.0], stride=4):dxbinding(1),glbinding(1)
+//TEST_INPUT: Texture2D(size=4, content = zero) : dxbinding(0),glbinding(0)
+//TEST_INPUT: Texture2D(size=4, content = one) : dxbinding(1),glbinding(1)
+//TEST_INPUT: Sampler : dxbinding(0),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: Sampler : dxbinding(1),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT: type Impl
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IBase
+{
+ float compute();
+}
+
+struct Impl : IBase
+{
+ float base; // = 1.0
+ Texture2D tex;
+ SamplerState sampler;
+ float compute()
+ {
+ return tex.SampleLevel(sampler, float2(0.0), 0.0).x + base;
+ }
+};
+
+__generic_param TImpl : IBase;
+
+TImpl impl;
+
+float base0; // = 0.5
+
+Texture2D tex1; // = 0.0
+SamplerState sampler;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5
+ float outVal = impl.compute(); // = 2.0
+ outVal = b0 / outVal; // = 0.25
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/global-type-param1.slang.expected.txt b/tests/compute/global-type-param1.slang.expected.txt
new file mode 100644
index 000000000..4846e7be2
--- /dev/null
+++ b/tests/compute/global-type-param1.slang.expected.txt
@@ -0,0 +1 @@
+3E800000
diff --git a/tests/compute/global-type-param2.slang b/tests/compute/global-type-param2.slang
new file mode 100644
index 000000000..b54f4c430
--- /dev/null
+++ b/tests/compute/global-type-param2.slang
@@ -0,0 +1,61 @@
+//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
+//TEST_INPUT: cbuffer(data=[0.5 0 0 0], stride=4):dxbinding(0),glbinding(0)
+//TEST_INPUT: cbuffer(data=[1.0], stride=4):dxbinding(1),glbinding(1)
+//TEST_INPUT: Texture2D(size=4, content = zero) : dxbinding(0),glbinding(0)
+//TEST_INPUT: Texture2D(size=4, content = one) : dxbinding(1),glbinding(1)
+//TEST_INPUT: Sampler : dxbinding(0),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: Sampler : dxbinding(1),glbinding(0,1,2,3,4,5,6)
+//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
+//TEST_INPUT: type Impl
+
+
+/* Testing this scenario:
+The ProgramLayout before specializeIRForEntryPoint() has no global cbuffer
+allocated because the program before specialization does not define any
+global uniform variables.
+
+However, after specialization, we find ourselves needing a global constant
+buffer. The compiler should allocate the next available constant buffer slot
+(here c1 because c0 is already used by the explicit cbuffer decl `existingBuffer`)
+for the newly generated global constant buffer.
+*/
+
+RWStructuredBuffer<float> outputBuffer;
+
+interface IBase
+{
+ float compute();
+}
+
+struct Impl : IBase
+{
+ float base; // = 1.0
+ Texture2D tex;
+ SamplerState sampler;
+ float compute()
+ {
+ return tex.SampleLevel(sampler, float2(0.0), 0.0).x + base;
+ }
+};
+
+__generic_param TImpl : IBase;
+
+TImpl impl;
+
+// at binding c0:
+cbuffer existingBuffer
+{
+ float base0; // = 0.5
+}
+Texture2D tex1; // = 0.0
+SamplerState sampler;
+
+[numthreads(1, 1, 1)]
+void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
+{
+ uint tid = dispatchThreadID.x;
+ float b0 = tex1.SampleLevel(sampler, float2(0.0), 0.0).x + base0; // = 0.5
+ float outVal = impl.compute(); // = 2.0
+ outVal = b0 / outVal; // = 0.25
+ outputBuffer[tid] = outVal;
+} \ No newline at end of file
diff --git a/tests/compute/global-type-param2.slang.expected.txt b/tests/compute/global-type-param2.slang.expected.txt
new file mode 100644
index 000000000..4846e7be2
--- /dev/null
+++ b/tests/compute/global-type-param2.slang.expected.txt
@@ -0,0 +1 @@
+3E800000
diff --git a/tools/render-test/main.cpp b/tools/render-test/main.cpp
index cb0eb927d..51a96436f 100644
--- a/tools/render-test/main.cpp
+++ b/tools/render-test/main.cpp
@@ -117,6 +117,7 @@ Error initializeShaders(
compileRequest.computeShader.name = computeEntryPointName;
compileRequest.computeShader.profile = computeProfileName;
}
+ compileRequest.entryPointTypeArguments = gShaderInputLayout.globalTypeArguments;
gShaderProgram = shaderCompiler->compileProgram(compileRequest);
if( !gShaderProgram )
{
diff --git a/tools/render-test/render-d3d11.cpp b/tools/render-test/render-d3d11.cpp
index 9bac24094..cdd6c778e 100644
--- a/tools/render-test/render-d3d11.cpp
+++ b/tools/render-test/render-d3d11.cpp
@@ -455,10 +455,18 @@ public:
ID3D11Buffer * buffer = nullptr;
};
+ UInt RoundUpToAlignment(UInt size, UInt alignment)
+ {
+ if (size % alignment)
+ return (size / alignment + 1) * alignment;
+ else
+ return Math::Max(size, alignment);
+ }
+
virtual Buffer* createBuffer(BufferDesc const& desc) override
{
D3D11_BUFFER_DESC dxBufferDesc = { 0 };
- dxBufferDesc.ByteWidth = (UINT) desc.size;
+ dxBufferDesc.ByteWidth = (UINT)RoundUpToAlignment(desc.size, 256);
switch( desc.flavor )
{
@@ -773,7 +781,11 @@ public:
{
auto dxContext = dxImmediateContext;
D3D11_BUFFER_DESC desc = {0};
- desc.ByteWidth = (UINT)(bufferData.Count() * sizeof(unsigned int));
+ List<unsigned int> newBuffer;
+ desc.ByteWidth = (UINT)RoundUpToAlignment((bufferData.Count() * sizeof(unsigned int)), 256);
+ newBuffer.SetSize(desc.ByteWidth / sizeof(unsigned int));
+ for (UInt i = 0; i < bufferData.Count(); i++)
+ newBuffer[i] = bufferData[i];
if (bufferDesc.type == InputBufferType::ConstantBuffer)
{
desc.Usage = D3D11_USAGE_DEFAULT;
@@ -794,7 +806,7 @@ public:
}
}
D3D11_SUBRESOURCE_DATA data = {0};
- data.pSysMem = bufferData.Buffer();
+ data.pSysMem = newBuffer.Buffer();
dxDevice->CreateBuffer(&desc, &data, &bufferOut);
int elemSize = bufferDesc.stride <= 0 ? 1 : bufferDesc.stride;
if (bufferDesc.type == InputBufferType::StorageBuffer)
@@ -1091,7 +1103,7 @@ public:
D3D11_BUFFER_DESC bufDesc;
memset(&bufDesc, 0, sizeof(bufDesc));
bufDesc.BindFlags = 0;
- bufDesc.ByteWidth = binding.bufferLength;
+ bufDesc.ByteWidth = (UINT)RoundUpToAlignment(binding.bufferLength, 256);
bufDesc.CPUAccessFlags = D3D11_CPU_ACCESS_READ;
bufDesc.Usage = D3D11_USAGE_STAGING;
dxDevice->CreateBuffer(&bufDesc, nullptr, &stageBuf);
diff --git a/tools/render-test/render.h b/tools/render-test/render.h
index dec48cda4..174ba0b7b 100644
--- a/tools/render-test/render.h
+++ b/tools/render-test/render.h
@@ -31,6 +31,7 @@ struct ShaderCompileRequest
EntryPoint vertexShader;
EntryPoint fragmentShader;
EntryPoint computeShader;
+ Slang::List<Slang::String> entryPointTypeArguments;
};
class ShaderCompiler
diff --git a/tools/render-test/shader-input-layout.cpp b/tools/render-test/shader-input-layout.cpp
index ef78fe3d5..01328eabd 100644
--- a/tools/render-test/shader-input-layout.cpp
+++ b/tools/render-test/shader-input-layout.cpp
@@ -7,6 +7,7 @@ namespace renderer_test
void ShaderInputLayout::Parse(const char * source)
{
entries.Clear();
+ globalTypeArguments.Clear();
auto lines = Split(source, '\n');
for (auto & line : lines)
{
@@ -16,200 +17,208 @@ namespace renderer_test
TokenReader parser(lineContent);
try
{
- ShaderInputLayoutEntry entry;
-
- if (parser.LookAhead("cbuffer"))
- {
- entry.type = ShaderInputType::Buffer;
- entry.bufferDesc.type = InputBufferType::ConstantBuffer;
- }
- else if (parser.LookAhead("ubuffer"))
- {
- entry.type = ShaderInputType::Buffer;
- entry.bufferDesc.type = InputBufferType::StorageBuffer;
- }
- else if (parser.LookAhead("Texture1D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 1;
- }
- else if (parser.LookAhead("Texture2D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- }
- else if (parser.LookAhead("Texture3D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 3;
- }
- else if (parser.LookAhead("TextureCube"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isCube = true;
- }
- else if (parser.LookAhead("RWTexture1D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 1;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("RWTexture2D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("RWTexture3D"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 3;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("RWTextureCube"))
- {
- entry.type = ShaderInputType::Texture;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isCube = true;
- entry.textureDesc.isRWTexture = true;
- }
- else if (parser.LookAhead("Sampler"))
- {
- entry.type = ShaderInputType::Sampler;
- }
- else if (parser.LookAhead("Sampler1D"))
- {
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 1;
- }
- else if (parser.LookAhead("Sampler2D"))
- {
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 2;
- }
- else if (parser.LookAhead("Sampler3D"))
- {
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 3;
- }
- else if (parser.LookAhead("SamplerCube"))
+ if (parser.LookAhead("type"))
{
- entry.type = ShaderInputType::CombinedTextureSampler;
- entry.textureDesc.dimension = 2;
- entry.textureDesc.isCube = true;
+ parser.ReadToken();
+ globalTypeArguments.Add(parser.ReadWord());
}
- else if (parser.LookAhead("render_targets"))
+ else
{
- numRenderTargets = parser.ReadInt();
- continue;
- }
- parser.ReadToken();
- // parse options
- if (parser.LookAhead("("))
- {
- parser.Read("(");
- while (!parser.IsEnd() && !parser.LookAhead(")"))
+ ShaderInputLayoutEntry entry;
+
+ if (parser.LookAhead("cbuffer"))
{
- auto word = parser.ReadWord();
- if (word == "depth")
- {
- entry.textureDesc.isDepthTexture = true;
- }
- else if (word == "depthCompare")
- {
- entry.samplerDesc.isCompareSampler = true;
- }
- else if (word == "arrayLength")
- {
- parser.Read("=");
- entry.textureDesc.arrayLength = parser.ReadInt();
- }
- else if (word == "stride")
- {
- parser.Read("=");
- entry.bufferDesc.stride = parser.ReadInt();
- }
- else if (word == "size")
- {
- parser.Read("=");
- entry.textureDesc.size = parser.ReadInt();
- }
- else if (word == "data")
+ entry.type = ShaderInputType::Buffer;
+ entry.bufferDesc.type = InputBufferType::ConstantBuffer;
+ }
+ else if (parser.LookAhead("ubuffer"))
+ {
+ entry.type = ShaderInputType::Buffer;
+ entry.bufferDesc.type = InputBufferType::StorageBuffer;
+ }
+ else if (parser.LookAhead("Texture1D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 1;
+ }
+ else if (parser.LookAhead("Texture2D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ }
+ else if (parser.LookAhead("Texture3D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 3;
+ }
+ else if (parser.LookAhead("TextureCube"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isCube = true;
+ }
+ else if (parser.LookAhead("RWTexture1D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 1;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("RWTexture2D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("RWTexture3D"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 3;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("RWTextureCube"))
+ {
+ entry.type = ShaderInputType::Texture;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isCube = true;
+ entry.textureDesc.isRWTexture = true;
+ }
+ else if (parser.LookAhead("Sampler"))
+ {
+ entry.type = ShaderInputType::Sampler;
+ }
+ else if (parser.LookAhead("Sampler1D"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 1;
+ }
+ else if (parser.LookAhead("Sampler2D"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 2;
+ }
+ else if (parser.LookAhead("Sampler3D"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 3;
+ }
+ else if (parser.LookAhead("SamplerCube"))
+ {
+ entry.type = ShaderInputType::CombinedTextureSampler;
+ entry.textureDesc.dimension = 2;
+ entry.textureDesc.isCube = true;
+ }
+ else if (parser.LookAhead("render_targets"))
+ {
+ numRenderTargets = parser.ReadInt();
+ continue;
+ }
+ parser.ReadToken();
+ // parse options
+ if (parser.LookAhead("("))
+ {
+ parser.Read("(");
+ while (!parser.IsEnd() && !parser.LookAhead(")"))
{
- parser.Read("=");
- parser.Read("[");
- while (!parser.IsEnd() && !parser.LookAhead("]"))
+ auto word = parser.ReadWord();
+ if (word == "depth")
+ {
+ entry.textureDesc.isDepthTexture = true;
+ }
+ else if (word == "depthCompare")
+ {
+ entry.samplerDesc.isCompareSampler = true;
+ }
+ else if (word == "arrayLength")
{
- if (parser.NextToken().Type == TokenType::IntLiteral)
+ parser.Read("=");
+ entry.textureDesc.arrayLength = parser.ReadInt();
+ }
+ else if (word == "stride")
+ {
+ parser.Read("=");
+ entry.bufferDesc.stride = parser.ReadInt();
+ }
+ else if (word == "size")
+ {
+ parser.Read("=");
+ entry.textureDesc.size = parser.ReadInt();
+ }
+ else if (word == "data")
+ {
+ parser.Read("=");
+ parser.Read("[");
+ while (!parser.IsEnd() && !parser.LookAhead("]"))
{
- entry.bufferData.Add(parser.ReadUInt());
+ if (parser.NextToken().Type == TokenType::IntLiteral)
+ {
+ entry.bufferData.Add(parser.ReadUInt());
+ }
+ else
+ {
+ auto floatNum = parser.ReadFloat();
+ entry.bufferData.Add(*(unsigned int*)&floatNum);
+ }
}
+ parser.Read("]");
+ }
+ else if (word == "content")
+ {
+ parser.Read("=");
+ auto contentWord = parser.ReadWord();
+ if (contentWord == "zero")
+ entry.textureDesc.content = InputTextureContent::Zero;
+ else if (contentWord == "one")
+ entry.textureDesc.content = InputTextureContent::One;
+ else if (contentWord == "chessboard")
+ entry.textureDesc.content = InputTextureContent::ChessBoard;
else
- {
- auto floatNum = parser.ReadFloat();
- entry.bufferData.Add(*(unsigned int*)&floatNum);
- }
+ entry.textureDesc.content = InputTextureContent::Gradient;
}
- parser.Read("]");
- }
- else if (word == "content")
- {
- parser.Read("=");
- auto contentWord = parser.ReadWord();
- if (contentWord == "zero")
- entry.textureDesc.content = InputTextureContent::Zero;
- else if (contentWord == "one")
- entry.textureDesc.content = InputTextureContent::One;
- else if (contentWord == "chessboard")
- entry.textureDesc.content = InputTextureContent::ChessBoard;
+ if (parser.LookAhead(","))
+ parser.Read(",");
else
- entry.textureDesc.content = InputTextureContent::Gradient;
+ break;
}
- if (parser.LookAhead(","))
- parser.Read(",");
- else
- break;
+ parser.Read(")");
}
- parser.Read(")");
- }
- // parse bindings
- if (parser.LookAhead(":"))
- {
- parser.Read(":");
- while (!parser.IsEnd())
+ // parse bindings
+ if (parser.LookAhead(":"))
{
- if (parser.LookAhead("dxbinding"))
- {
- parser.ReadToken();
- parser.Read("(");
- entry.hlslBinding = parser.ReadInt();
- parser.Read(")");
- }
- else if (parser.LookAhead("glbinding"))
+ parser.Read(":");
+ while (!parser.IsEnd())
{
- parser.ReadToken();
- parser.Read("(");
- while (!parser.IsEnd() && !parser.LookAhead(")"))
+ if (parser.LookAhead("dxbinding"))
{
- entry.glslBinding.Add(parser.ReadInt());
- if (parser.LookAhead(","))
- parser.Read(",");
- else
- break;
+ parser.ReadToken();
+ parser.Read("(");
+ entry.hlslBinding = parser.ReadInt();
+ parser.Read(")");
}
- parser.Read(")");
- }
- else if (parser.LookAhead("out"))
- {
- parser.ReadToken();
- entry.isOutput = true;
+ else if (parser.LookAhead("glbinding"))
+ {
+ parser.ReadToken();
+ parser.Read("(");
+ while (!parser.IsEnd() && !parser.LookAhead(")"))
+ {
+ entry.glslBinding.Add(parser.ReadInt());
+ if (parser.LookAhead(","))
+ parser.Read(",");
+ else
+ break;
+ }
+ parser.Read(")");
+ }
+ else if (parser.LookAhead("out"))
+ {
+ parser.ReadToken();
+ entry.isOutput = true;
+ }
+ if (parser.LookAhead(","))
+ parser.Read(",");
}
- if (parser.LookAhead(","))
- parser.Read(",");
}
+ entries.Add(entry);
}
- entries.Add(entry);
}
catch (TextFormatException)
{
diff --git a/tools/render-test/shader-input-layout.h b/tools/render-test/shader-input-layout.h
index 9602e4fe8..c4c3d9d8c 100644
--- a/tools/render-test/shader-input-layout.h
+++ b/tools/render-test/shader-input-layout.h
@@ -63,6 +63,7 @@ namespace renderer_test
{
public:
Slang::List<ShaderInputLayoutEntry> entries;
+ Slang::List<Slang::String> globalTypeArguments;
int numRenderTargets = 1;
void Parse(const char * source);
};
diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp
index 63e24126c..746967cb7 100644
--- a/tools/render-test/slang-support.cpp
+++ b/tools/render-test/slang-support.cpp
@@ -84,7 +84,14 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
ShaderProgram * result = nullptr;
if (request.computeShader.name)
{
- int computeEntryPoint = spAddEntryPoint(slangRequest, computeTranslationUnit, computeEntryPointName, spFindProfile(slangSession, request.computeShader.profile));
+ Slang::List<const char*> rawTypeNames;
+ for (auto typeName : request.entryPointTypeArguments)
+ rawTypeNames.Add(typeName.Buffer());
+ int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit,
+ computeEntryPointName,
+ spFindProfile(slangSession, request.computeShader.profile),
+ (int)rawTypeNames.Count(),
+ rawTypeNames.Buffer());
int compileErr = spCompile(slangRequest);
if (auto diagnostics = spGetDiagnosticOutput(slangRequest))
{