summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorYong He <yonghe@outlook.com>2017-11-22 17:32:15 -0500
committerGitHub <noreply@github.com>2017-11-22 17:32:15 -0500
commit83d49ce376185f7dc3f40eb531f01ee350220959 (patch)
tree7e96f26c6b6e6bf6a8b15ba1820e844e78a31394
parent56e49feea3956d66e41b819c26628c65b3c28197 (diff)
parent581b30dd5a4263c90539a8c5cc6063b2485885cd (diff)
Merge pull request #293 from csyonghe/generic-param-fix
Fixup global generic parameters
-rw-r--r--source/slang/ir.cpp31
-rw-r--r--source/slang/parameter-binding.cpp38
-rw-r--r--source/slang/syntax.cpp47
-rw-r--r--source/slang/syntax.h1
-rw-r--r--source/slang/type-layout.h2
-rw-r--r--tests/compute/global-type-param-in-entrypoint.slang96
-rw-r--r--tests/compute/global-type-param-in-entrypoint.slang.expected.txt4
-rw-r--r--tools/render-test/render-d3d11.cpp5
-rw-r--r--tools/render-test/slang-support.cpp12
9 files changed, 211 insertions, 25 deletions
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 0f34d5585..598445fcd 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -3103,10 +3103,30 @@ namespace Slang
IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar);
IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc);
IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar);
+ RefPtr<Substitutions> cloneSubstitutions(
+ IRSpecContext* context,
+ Substitutions* subst);
RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType)
{
- return originalType->Substitute(subst).As<Type>();
+ auto rsType = originalType->GetCanonicalType()->Substitute(subst).As<Type>();
+ if (auto declRefType = rsType.As<DeclRefType>())
+ {
+ if (subst)
+ {
+ auto newSubst = cloneSubstitutions(this, subst);
+ insertSubstAtBottom(declRefType->declRef.substitutions, newSubst);
+ }
+ }
+ else if (auto funcType = rsType.As<FuncType>())
+ {
+ RefPtr<FuncType> newFuncType = new FuncType();
+ newFuncType->setSession(funcType->getSession());
+ newFuncType->resultType = maybeCloneType(funcType->resultType);
+ for (auto paramType : funcType->paramTypes)
+ newFuncType->paramTypes.Add(maybeCloneType(paramType));
+ }
+ return rsType;
}
IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue)
@@ -3243,6 +3263,15 @@ namespace Slang
newSubst->outer = cloneSubstitutions(context, subst->outer);
return newSubst;
}
+ else if (auto genTypeSubst = dynamic_cast<GlobalGenericParamSubstitution*>(subst))
+ {
+ RefPtr<GlobalGenericParamSubstitution> newSubst = new GlobalGenericParamSubstitution();
+ newSubst->actualType = genTypeSubst->actualType;
+ newSubst->paramDecl = genTypeSubst->paramDecl;
+ newSubst->witnessTables = genTypeSubst->witnessTables;
+ newSubst->outer = cloneSubstitutions(context, subst->outer);
+ return newSubst;
+ }
else
SLANG_UNREACHABLE("unimplemented cloneSubstitution");
UNREACHABLE_RETURN(nullptr);
diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp
index 836ed254f..0daa2abc7 100644
--- a/source/slang/parameter-binding.cpp
+++ b/source/slang/parameter-binding.cpp
@@ -1425,6 +1425,16 @@ static RefPtr<TypeLayout> processEntryPointParameter(
return structLayout;
}
+ else if (auto globalGenericParam = declRef.As<GlobalGenericParamDecl>())
+ {
+ auto genParamTypeLayout = new GenericParamTypeLayout();
+ // we should have already populated ProgramLayout::genericEntryPointParams list at this point,
+ // so we can find the index of this generic param decl in the list
+ genParamTypeLayout->type = type;
+ genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl());
+ genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count++;
+ return genParamTypeLayout;
+ }
else
{
SLANG_UNEXPECTED("unhandled type kind");
@@ -1442,7 +1452,8 @@ static RefPtr<TypeLayout> processEntryPointParameter(
static void collectEntryPointParameters(
ParameterBindingContext* context,
- EntryPointRequest* entryPoint)
+ EntryPointRequest* entryPoint,
+ Substitutions* typeSubst)
{
FuncDecl* entryPointFuncDecl = entryPoint->decl;
if (!entryPointFuncDecl)
@@ -1507,7 +1518,7 @@ static void collectEntryPointParameters(
auto paramTypeLayout = processEntryPointParameterWithPossibleSemantic(
context,
paramDecl.Ptr(),
- paramDecl->type.type,
+ paramDecl->type.type->Substitute(typeSubst).As<Type>(),
state,
paramVarLayout);
@@ -1539,7 +1550,7 @@ static void collectEntryPointParameters(
auto resultTypeLayout = processEntryPointParameterWithPossibleSemantic(
context,
entryPointFuncDecl,
- resultType,
+ resultType->Substitute(typeSubst).As<Type>(),
state,
resultLayout);
@@ -1632,7 +1643,7 @@ static void collectParameters(
for( auto& entryPoint : translationUnit->entryPoints )
{
context->stage = entryPoint->profile.GetStage();
- collectEntryPointParameters(context, entryPoint.Ptr());
+ collectEntryPointParameters(context, entryPoint.Ptr(), nullptr);
}
}
@@ -1891,13 +1902,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
newProgramLayout = new ProgramLayout();
newProgramLayout->bindingForHackSampler = programLayout->bindingForHackSampler;
newProgramLayout->hackSamplerVar = programLayout->hackSamplerVar;
- for (auto & entryPoint : programLayout->entryPoints)
- {
- RefPtr<EntryPointLayout> newEntryPoint = new EntryPointLayout(*entryPoint);
- // TODO: for now just copy existing entry point layouts, but we eventually need to
- // specialize these as well...
- newProgramLayout->entryPoints.Add(newEntryPoint);
- }
+ newProgramLayout->globalGenericParams = programLayout->globalGenericParams;
List<RefPtr<TypeLayout>> paramTypeLayouts;
auto globalStructLayout = getGlobalStructLayout(programLayout);
@@ -1919,7 +1924,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
SharedParameterBindingContext sharedContext;
sharedContext.compileRequest = targetReq->compileRequest;
sharedContext.defaultLayoutRules = layoutContext.getRulesFamily();
- sharedContext.programLayout = programLayout;
+ sharedContext.programLayout = newProgramLayout;
// Create a sub-context to collect parameters that get
// declared into the global scope
@@ -1928,6 +1933,15 @@ RefPtr<ProgramLayout> specializeProgramLayout(
context.translationUnit = nullptr;
context.layoutContext = layoutContext;
+
+ for (auto & translationUnit : targetReq->compileRequest->translationUnits)
+ {
+ for (auto & entryPoint : translationUnit->entryPoints)
+ {
+ collectEntryPointParameters(&context, entryPoint, typeSubst);
+ }
+ }
+
auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules();
structLayout->rules = constantBufferRules;
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index e43dd9074..2c214a332 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -1309,9 +1309,39 @@ void Type::accept(IValVisitor* visitor, void* extra)
UNREACHABLE_RETURN(expr);
}
+ bool hasGlobalGenericSubst(Substitutions * destSubst, GlobalGenericParamSubstitution * genSubst)
+ {
+ while (destSubst)
+ {
+ if (auto globalParamSubst = dynamic_cast<GlobalGenericParamSubstitution*>(destSubst))
+ {
+ if (globalParamSubst->paramDecl == genSubst->paramDecl)
+ return true;
+ }
+ destSubst = destSubst->outer;
+ }
+ return false;
+ }
+ void insertGlobalGenericSubstitutions(RefPtr<Substitutions> & destSubst, Substitutions * srcSubst)
+ {
+ while (srcSubst)
+ {
+ if (auto globalGenSubst = dynamic_cast<GlobalGenericParamSubstitution*>(srcSubst))
+ {
+ if (!hasGlobalGenericSubst(destSubst, globalGenSubst))
+ {
+ RefPtr<GlobalGenericParamSubstitution> cpyGlobalGenSubst = new GlobalGenericParamSubstitution(*globalGenSubst);
+ cpyGlobalGenSubst->outer = nullptr;
+ insertSubstAtBottom(destSubst, cpyGlobalGenSubst);
+ }
+ }
+ srcSubst = srcSubst->outer;
+ }
+ }
DeclRefBase DeclRefBase::SubstituteImpl(Substitutions* subst, int* ioDiff)
{
+ insertGlobalGenericSubstitutions(substitutions, subst);
if (!substitutions) return *this;
int diff = 0;
@@ -1709,7 +1739,22 @@ void Type::accept(IValVisitor* visitor, void* extra)
return sb.ProduceString();
}
-
+ void insertSubstAtBottom(RefPtr<Substitutions> & substHead, RefPtr<Substitutions> substToInsert)
+ {
+ if (!substHead)
+ {
+ substHead = substToInsert;
+ return;
+ }
+ auto subst = substHead;
+ RefPtr<Substitutions> lastSubst = subst;
+ while (subst->outer)
+ {
+ lastSubst = subst;
+ subst = subst->outer;
+ }
+ lastSubst->outer = substToInsert;
+ }
void insertSubstAtTop(DeclRefBase & declRef, RefPtr<Substitutions> substToInsert)
{
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index b4d550ef5..f3690d9ae 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -1156,6 +1156,7 @@ namespace Slang
Session* session,
Decl* decl);
+ void insertSubstAtBottom(RefPtr<Substitutions> & substHead, RefPtr<Substitutions> substToInsert);
RefPtr<ThisTypeSubstitution> getNewThisTypeSubst(DeclRefBase & declRef);
RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry);
void removeSubstitution(DeclRefBase & declRef, RefPtr<Substitutions> subst);
diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h
index 4ce6dc355..07530bdfc 100644
--- a/source/slang/type-layout.h
+++ b/source/slang/type-layout.h
@@ -676,7 +676,7 @@ createStructuredBufferTypeLayout(
RefPtr<Type> structuredBufferType,
RefPtr<Type> elementType);
-
+int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl);
//
}
diff --git a/tests/compute/global-type-param-in-entrypoint.slang b/tests/compute/global-type-param-in-entrypoint.slang
new file mode 100644
index 000000000..5d8036d98
--- /dev/null
+++ b/tests/compute/global-type-param-in-entrypoint.slang
@@ -0,0 +1,96 @@
+//TEST(compute):COMPARE_RENDER_COMPUTE:-xslang -use-ir
+//TEST_INPUT: cbuffer(data=[1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0], stride=16):dxbinding(0),glbinding(0)
+//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(1),glbinding(0),out
+//TEST_INPUT: type VertImpl
+
+interface IVertInterpolant
+{
+ float4 getColor();
+}
+
+__generic_param TVertInterpolant : IVertInterpolant;
+
+struct VertImpl : IVertInterpolant
+{
+ float3 color;
+ float4 getColor()
+ {
+ return float4(1.0);
+ }
+};
+
+RWStructuredBuffer<float> outputBuffer;
+
+cbuffer Uniforms
+{
+ float4x4 modelViewProjection;
+}
+
+struct AssembledVertex
+{
+ float3 position;
+ TVertInterpolant interpolants;
+ float2 uv;
+};
+
+struct CoarseVertex
+{
+ TVertInterpolant interpolants;
+ float2 uv;
+};
+
+struct Fragment
+{
+ float4 color;
+};
+
+
+// Vertex Shader
+
+struct VertexStageInput
+{
+ AssembledVertex assembledVertex : A;
+};
+
+struct VertexStageOutput
+{
+ CoarseVertex coarseVertex : CoarseVertex;
+ float4 sv_position : SV_Position;
+};
+
+VertexStageOutput vertexMain(VertexStageInput input)
+{
+ VertexStageOutput output;
+
+ float3 position = input.assembledVertex.position;
+ output.coarseVertex.interpolants = input.assembledVertex.interpolants;
+ output.sv_position = mul(modelViewProjection, float4(position, 1.0));
+ output.coarseVertex.uv = input.assembledVertex.uv;
+ return output;
+}
+
+// Fragment Shader
+
+struct FragmentStageInput
+{
+ CoarseVertex coarseVertex : CoarseVertex;
+};
+
+struct FragmentStageOutput
+{
+ Fragment fragment : SV_Target;
+};
+
+FragmentStageOutput fragmentMain(FragmentStageInput input)
+{
+ FragmentStageOutput output;
+
+ float4 color = input.coarseVertex.interpolants.getColor();
+ float2 uv = input.coarseVertex.uv;
+ output.fragment.color = color;
+ outputBuffer[0] = color.x;
+ outputBuffer[1] = color.y;
+ outputBuffer[2] = color.z;
+ outputBuffer[3] = color.w;
+ return output;
+} \ No newline at end of file
diff --git a/tests/compute/global-type-param-in-entrypoint.slang.expected.txt b/tests/compute/global-type-param-in-entrypoint.slang.expected.txt
new file mode 100644
index 000000000..e143b7f20
--- /dev/null
+++ b/tests/compute/global-type-param-in-entrypoint.slang.expected.txt
@@ -0,0 +1,4 @@
+3F800000
+3F800000
+3F800000
+3F800000 \ No newline at end of file
diff --git a/tools/render-test/render-d3d11.cpp b/tools/render-test/render-d3d11.cpp
index cdd6c778e..d0280a770 100644
--- a/tools/render-test/render-d3d11.cpp
+++ b/tools/render-test/render-d3d11.cpp
@@ -457,10 +457,7 @@ public:
UInt RoundUpToAlignment(UInt size, UInt alignment)
{
- if (size % alignment)
- return (size / alignment + 1) * alignment;
- else
- return Math::Max(size, alignment);
+ return ((size + alignment - 1) / alignment) * alignment;
}
virtual Buffer* createBuffer(BufferDesc const& desc) override
diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp
index 746967cb7..2465bfd99 100644
--- a/tools/render-test/slang-support.cpp
+++ b/tools/render-test/slang-support.cpp
@@ -82,12 +82,12 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
spSetCompileFlags(slangRequest, SLANG_COMPILE_FLAG_NO_CHECKING);
}
ShaderProgram * result = nullptr;
+ Slang::List<const char*> rawTypeNames;
+ for (auto typeName : request.entryPointTypeArguments)
+ rawTypeNames.Add(typeName.Buffer());
if (request.computeShader.name)
{
- Slang::List<const char*> rawTypeNames;
- for (auto typeName : request.entryPointTypeArguments)
- rawTypeNames.Add(typeName.Buffer());
- int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit,
+ int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit,
computeEntryPointName,
spFindProfile(slangSession, request.computeShader.profile),
(int)rawTypeNames.Count(),
@@ -107,8 +107,8 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
}
else
{
- int vertexEntryPoint = spAddEntryPoint(slangRequest, vertexTranslationUnit, vertexEntryPointName, spFindProfile(slangSession, request.vertexShader.profile));
- int fragmentEntryPoint = spAddEntryPoint(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, spFindProfile(slangSession, request.fragmentShader.profile));
+ int vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, spFindProfile(slangSession, request.vertexShader.profile), (int)rawTypeNames.Count(), rawTypeNames.Buffer());
+ int fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, spFindProfile(slangSession, request.fragmentShader.profile), (int)rawTypeNames.Count(), rawTypeNames.Buffer());
int compileErr = spCompile(slangRequest);
if (auto diagnostics = spGetDiagnosticOutput(slangRequest))