diff options
| author | Yong He <yonghe@outlook.com> | 2017-11-22 17:32:15 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-11-22 17:32:15 -0500 |
| commit | 83d49ce376185f7dc3f40eb531f01ee350220959 (patch) | |
| tree | 7e96f26c6b6e6bf6a8b15ba1820e844e78a31394 | |
| parent | 56e49feea3956d66e41b819c26628c65b3c28197 (diff) | |
| parent | 581b30dd5a4263c90539a8c5cc6063b2485885cd (diff) | |
Merge pull request #293 from csyonghe/generic-param-fix
Fixup global generic parameters
| -rw-r--r-- | source/slang/ir.cpp | 31 | ||||
| -rw-r--r-- | source/slang/parameter-binding.cpp | 38 | ||||
| -rw-r--r-- | source/slang/syntax.cpp | 47 | ||||
| -rw-r--r-- | source/slang/syntax.h | 1 | ||||
| -rw-r--r-- | source/slang/type-layout.h | 2 | ||||
| -rw-r--r-- | tests/compute/global-type-param-in-entrypoint.slang | 96 | ||||
| -rw-r--r-- | tests/compute/global-type-param-in-entrypoint.slang.expected.txt | 4 | ||||
| -rw-r--r-- | tools/render-test/render-d3d11.cpp | 5 | ||||
| -rw-r--r-- | tools/render-test/slang-support.cpp | 12 |
9 files changed, 211 insertions, 25 deletions
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 0f34d5585..598445fcd 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -3103,10 +3103,30 @@ namespace Slang IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar); IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc); IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar); + RefPtr<Substitutions> cloneSubstitutions( + IRSpecContext* context, + Substitutions* subst); RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType) { - return originalType->Substitute(subst).As<Type>(); + auto rsType = originalType->GetCanonicalType()->Substitute(subst).As<Type>(); + if (auto declRefType = rsType.As<DeclRefType>()) + { + if (subst) + { + auto newSubst = cloneSubstitutions(this, subst); + insertSubstAtBottom(declRefType->declRef.substitutions, newSubst); + } + } + else if (auto funcType = rsType.As<FuncType>()) + { + RefPtr<FuncType> newFuncType = new FuncType(); + newFuncType->setSession(funcType->getSession()); + newFuncType->resultType = maybeCloneType(funcType->resultType); + for (auto paramType : funcType->paramTypes) + newFuncType->paramTypes.Add(maybeCloneType(paramType)); + } + return rsType; } IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue) @@ -3243,6 +3263,15 @@ namespace Slang newSubst->outer = cloneSubstitutions(context, subst->outer); return newSubst; } + else if (auto genTypeSubst = dynamic_cast<GlobalGenericParamSubstitution*>(subst)) + { + RefPtr<GlobalGenericParamSubstitution> newSubst = new GlobalGenericParamSubstitution(); + newSubst->actualType = genTypeSubst->actualType; + newSubst->paramDecl = genTypeSubst->paramDecl; + newSubst->witnessTables = genTypeSubst->witnessTables; + newSubst->outer = cloneSubstitutions(context, subst->outer); + return newSubst; + } else SLANG_UNREACHABLE("unimplemented cloneSubstitution"); UNREACHABLE_RETURN(nullptr); diff --git a/source/slang/parameter-binding.cpp b/source/slang/parameter-binding.cpp index 836ed254f..0daa2abc7 100644 --- a/source/slang/parameter-binding.cpp +++ b/source/slang/parameter-binding.cpp @@ -1425,6 +1425,16 @@ static RefPtr<TypeLayout> processEntryPointParameter( return structLayout; } + else if (auto globalGenericParam = declRef.As<GlobalGenericParamDecl>()) + { + auto genParamTypeLayout = new GenericParamTypeLayout(); + // we should have already populated ProgramLayout::genericEntryPointParams list at this point, + // so we can find the index of this generic param decl in the list + genParamTypeLayout->type = type; + genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl()); + genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count++; + return genParamTypeLayout; + } else { SLANG_UNEXPECTED("unhandled type kind"); @@ -1442,7 +1452,8 @@ static RefPtr<TypeLayout> processEntryPointParameter( static void collectEntryPointParameters( ParameterBindingContext* context, - EntryPointRequest* entryPoint) + EntryPointRequest* entryPoint, + Substitutions* typeSubst) { FuncDecl* entryPointFuncDecl = entryPoint->decl; if (!entryPointFuncDecl) @@ -1507,7 +1518,7 @@ static void collectEntryPointParameters( auto paramTypeLayout = processEntryPointParameterWithPossibleSemantic( context, paramDecl.Ptr(), - paramDecl->type.type, + paramDecl->type.type->Substitute(typeSubst).As<Type>(), state, paramVarLayout); @@ -1539,7 +1550,7 @@ static void collectEntryPointParameters( auto resultTypeLayout = processEntryPointParameterWithPossibleSemantic( context, entryPointFuncDecl, - resultType, + resultType->Substitute(typeSubst).As<Type>(), state, resultLayout); @@ -1632,7 +1643,7 @@ static void collectParameters( for( auto& entryPoint : translationUnit->entryPoints ) { context->stage = entryPoint->profile.GetStage(); - collectEntryPointParameters(context, entryPoint.Ptr()); + collectEntryPointParameters(context, entryPoint.Ptr(), nullptr); } } @@ -1891,13 +1902,7 @@ RefPtr<ProgramLayout> specializeProgramLayout( newProgramLayout = new ProgramLayout(); newProgramLayout->bindingForHackSampler = programLayout->bindingForHackSampler; newProgramLayout->hackSamplerVar = programLayout->hackSamplerVar; - for (auto & entryPoint : programLayout->entryPoints) - { - RefPtr<EntryPointLayout> newEntryPoint = new EntryPointLayout(*entryPoint); - // TODO: for now just copy existing entry point layouts, but we eventually need to - // specialize these as well... - newProgramLayout->entryPoints.Add(newEntryPoint); - } + newProgramLayout->globalGenericParams = programLayout->globalGenericParams; List<RefPtr<TypeLayout>> paramTypeLayouts; auto globalStructLayout = getGlobalStructLayout(programLayout); @@ -1919,7 +1924,7 @@ RefPtr<ProgramLayout> specializeProgramLayout( SharedParameterBindingContext sharedContext; sharedContext.compileRequest = targetReq->compileRequest; sharedContext.defaultLayoutRules = layoutContext.getRulesFamily(); - sharedContext.programLayout = programLayout; + sharedContext.programLayout = newProgramLayout; // Create a sub-context to collect parameters that get // declared into the global scope @@ -1928,6 +1933,15 @@ RefPtr<ProgramLayout> specializeProgramLayout( context.translationUnit = nullptr; context.layoutContext = layoutContext; + + for (auto & translationUnit : targetReq->compileRequest->translationUnits) + { + for (auto & entryPoint : translationUnit->entryPoints) + { + collectEntryPointParameters(&context, entryPoint, typeSubst); + } + } + auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules(); structLayout->rules = constantBufferRules; diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index e43dd9074..2c214a332 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -1309,9 +1309,39 @@ void Type::accept(IValVisitor* visitor, void* extra) UNREACHABLE_RETURN(expr); } + bool hasGlobalGenericSubst(Substitutions * destSubst, GlobalGenericParamSubstitution * genSubst) + { + while (destSubst) + { + if (auto globalParamSubst = dynamic_cast<GlobalGenericParamSubstitution*>(destSubst)) + { + if (globalParamSubst->paramDecl == genSubst->paramDecl) + return true; + } + destSubst = destSubst->outer; + } + return false; + } + void insertGlobalGenericSubstitutions(RefPtr<Substitutions> & destSubst, Substitutions * srcSubst) + { + while (srcSubst) + { + if (auto globalGenSubst = dynamic_cast<GlobalGenericParamSubstitution*>(srcSubst)) + { + if (!hasGlobalGenericSubst(destSubst, globalGenSubst)) + { + RefPtr<GlobalGenericParamSubstitution> cpyGlobalGenSubst = new GlobalGenericParamSubstitution(*globalGenSubst); + cpyGlobalGenSubst->outer = nullptr; + insertSubstAtBottom(destSubst, cpyGlobalGenSubst); + } + } + srcSubst = srcSubst->outer; + } + } DeclRefBase DeclRefBase::SubstituteImpl(Substitutions* subst, int* ioDiff) { + insertGlobalGenericSubstitutions(substitutions, subst); if (!substitutions) return *this; int diff = 0; @@ -1709,7 +1739,22 @@ void Type::accept(IValVisitor* visitor, void* extra) return sb.ProduceString(); } - + void insertSubstAtBottom(RefPtr<Substitutions> & substHead, RefPtr<Substitutions> substToInsert) + { + if (!substHead) + { + substHead = substToInsert; + return; + } + auto subst = substHead; + RefPtr<Substitutions> lastSubst = subst; + while (subst->outer) + { + lastSubst = subst; + subst = subst->outer; + } + lastSubst->outer = substToInsert; + } void insertSubstAtTop(DeclRefBase & declRef, RefPtr<Substitutions> substToInsert) { diff --git a/source/slang/syntax.h b/source/slang/syntax.h index b4d550ef5..f3690d9ae 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -1156,6 +1156,7 @@ namespace Slang Session* session, Decl* decl); + void insertSubstAtBottom(RefPtr<Substitutions> & substHead, RefPtr<Substitutions> substToInsert); RefPtr<ThisTypeSubstitution> getNewThisTypeSubst(DeclRefBase & declRef); RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry); void removeSubstitution(DeclRefBase & declRef, RefPtr<Substitutions> subst); diff --git a/source/slang/type-layout.h b/source/slang/type-layout.h index 4ce6dc355..07530bdfc 100644 --- a/source/slang/type-layout.h +++ b/source/slang/type-layout.h @@ -676,7 +676,7 @@ createStructuredBufferTypeLayout( RefPtr<Type> structuredBufferType, RefPtr<Type> elementType); - +int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl); // } diff --git a/tests/compute/global-type-param-in-entrypoint.slang b/tests/compute/global-type-param-in-entrypoint.slang new file mode 100644 index 000000000..5d8036d98 --- /dev/null +++ b/tests/compute/global-type-param-in-entrypoint.slang @@ -0,0 +1,96 @@ +//TEST(compute):COMPARE_RENDER_COMPUTE:-xslang -use-ir +//TEST_INPUT: cbuffer(data=[1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0], stride=16):dxbinding(0),glbinding(0) +//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(1),glbinding(0),out +//TEST_INPUT: type VertImpl + +interface IVertInterpolant +{ + float4 getColor(); +} + +__generic_param TVertInterpolant : IVertInterpolant; + +struct VertImpl : IVertInterpolant +{ + float3 color; + float4 getColor() + { + return float4(1.0); + } +}; + +RWStructuredBuffer<float> outputBuffer; + +cbuffer Uniforms +{ + float4x4 modelViewProjection; +} + +struct AssembledVertex +{ + float3 position; + TVertInterpolant interpolants; + float2 uv; +}; + +struct CoarseVertex +{ + TVertInterpolant interpolants; + float2 uv; +}; + +struct Fragment +{ + float4 color; +}; + + +// Vertex Shader + +struct VertexStageInput +{ + AssembledVertex assembledVertex : A; +}; + +struct VertexStageOutput +{ + CoarseVertex coarseVertex : CoarseVertex; + float4 sv_position : SV_Position; +}; + +VertexStageOutput vertexMain(VertexStageInput input) +{ + VertexStageOutput output; + + float3 position = input.assembledVertex.position; + output.coarseVertex.interpolants = input.assembledVertex.interpolants; + output.sv_position = mul(modelViewProjection, float4(position, 1.0)); + output.coarseVertex.uv = input.assembledVertex.uv; + return output; +} + +// Fragment Shader + +struct FragmentStageInput +{ + CoarseVertex coarseVertex : CoarseVertex; +}; + +struct FragmentStageOutput +{ + Fragment fragment : SV_Target; +}; + +FragmentStageOutput fragmentMain(FragmentStageInput input) +{ + FragmentStageOutput output; + + float4 color = input.coarseVertex.interpolants.getColor(); + float2 uv = input.coarseVertex.uv; + output.fragment.color = color; + outputBuffer[0] = color.x; + outputBuffer[1] = color.y; + outputBuffer[2] = color.z; + outputBuffer[3] = color.w; + return output; +}
\ No newline at end of file diff --git a/tests/compute/global-type-param-in-entrypoint.slang.expected.txt b/tests/compute/global-type-param-in-entrypoint.slang.expected.txt new file mode 100644 index 000000000..e143b7f20 --- /dev/null +++ b/tests/compute/global-type-param-in-entrypoint.slang.expected.txt @@ -0,0 +1,4 @@ +3F800000 +3F800000 +3F800000 +3F800000
\ No newline at end of file diff --git a/tools/render-test/render-d3d11.cpp b/tools/render-test/render-d3d11.cpp index cdd6c778e..d0280a770 100644 --- a/tools/render-test/render-d3d11.cpp +++ b/tools/render-test/render-d3d11.cpp @@ -457,10 +457,7 @@ public: UInt RoundUpToAlignment(UInt size, UInt alignment) { - if (size % alignment) - return (size / alignment + 1) * alignment; - else - return Math::Max(size, alignment); + return ((size + alignment - 1) / alignment) * alignment; } virtual Buffer* createBuffer(BufferDesc const& desc) override diff --git a/tools/render-test/slang-support.cpp b/tools/render-test/slang-support.cpp index 746967cb7..2465bfd99 100644 --- a/tools/render-test/slang-support.cpp +++ b/tools/render-test/slang-support.cpp @@ -82,12 +82,12 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler spSetCompileFlags(slangRequest, SLANG_COMPILE_FLAG_NO_CHECKING); } ShaderProgram * result = nullptr; + Slang::List<const char*> rawTypeNames; + for (auto typeName : request.entryPointTypeArguments) + rawTypeNames.Add(typeName.Buffer()); if (request.computeShader.name) { - Slang::List<const char*> rawTypeNames; - for (auto typeName : request.entryPointTypeArguments) - rawTypeNames.Add(typeName.Buffer()); - int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit, + int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit, computeEntryPointName, spFindProfile(slangSession, request.computeShader.profile), (int)rawTypeNames.Count(), @@ -107,8 +107,8 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler } else { - int vertexEntryPoint = spAddEntryPoint(slangRequest, vertexTranslationUnit, vertexEntryPointName, spFindProfile(slangSession, request.vertexShader.profile)); - int fragmentEntryPoint = spAddEntryPoint(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, spFindProfile(slangSession, request.fragmentShader.profile)); + int vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, spFindProfile(slangSession, request.vertexShader.profile), (int)rawTypeNames.Count(), rawTypeNames.Buffer()); + int fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, spFindProfile(slangSession, request.fragmentShader.profile), (int)rawTypeNames.Count(), rawTypeNames.Buffer()); int compileErr = spCompile(slangRequest); if (auto diagnostics = spGetDiagnosticOutput(slangRequest)) |
