From 57f737dc5111b75e2c9591b83eacd2219ea67d07 Mon Sep 17 00:00:00 2001 From: Yong He Date: Mon, 13 Nov 2017 18:22:03 -0500 Subject: Legalization of function parameter types. This commit addresses issue #275 This commit includes following changes: 1. legalize function parameter IRParam instructions 2. legalize function parameter types in IRFuncType 3. legalize call sites (IRCall) with proper arguments 4. legalize local vars that has a mixed resource type. --- source/slang/emit.cpp | 33 +-- source/slang/ir-legalize-types.cpp | 305 ++++++++++++++++++--- source/slang/lower-to-ir.cpp | 1 - tests/compute/func-param-legalize.slang | 35 +++ .../compute/func-param-legalize.slang.expected.txt | 4 + tests/compute/shaderlib.slang | 196 +++++++++++++ 6 files changed, 513 insertions(+), 61 deletions(-) create mode 100644 tests/compute/func-param-legalize.slang create mode 100644 tests/compute/func-param-legalize.slang.expected.txt create mode 100644 tests/compute/shaderlib.slang diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index a8cd11af4..c72adea38 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -97,8 +97,6 @@ struct SharedEmitContext Dictionary mapIRValueToID; HashSet irDeclsVisited; - - Dictionary irMapContinueTargetToLoopHead; }; struct EmitContext @@ -5566,10 +5564,10 @@ emitDeclImpl(decl, nullptr); emit("for(;;)\n{\n"); - // Register information so that `continue` sites - // can do the right thing: - ctx->shared->irMapContinueTargetToLoopHead.Add(continueBlock, targetBlock); - + // TODO: Okay, we *said* we'd do this special + // handling of the `continue` sites, but + // we aren't actually setting anything up here... + // emitIRStmtsForBlocks( ctx, @@ -5590,28 +5588,7 @@ emitDeclImpl(decl, nullptr); return; case kIROp_continue: - // With out current strategy for outputting loops, - // just outputting an AST-level `continue` here won't - // actually execute the statements in the continue block. - // - // Instead, we have to manually output those statements - // directly here, and *then* do an AST-level `continue`. - // - // This leads to code duplication when we have multiple - // `continue` sites in the original program, but it avoids - // introducing additional temporaries for control flow. - { - auto continueInst = (IRContinue*) terminator; - auto targetBlock = continueInst->getTargetBlock(); - IRBlock* loopHead = nullptr; - ctx->shared->irMapContinueTargetToLoopHead.TryGetValue(targetBlock, loopHead); - SLANG_ASSERT(loopHead); - emitIRStmtsForBlocks( - ctx, - targetBlock, - loopHead); - emit("continue;\n"); - } + emit("continue;\n"); return; case kIROp_loopTest: diff --git a/source/slang/ir-legalize-types.cpp b/source/slang/ir-legalize-types.cpp index 5b08acee8..2719f336b 100644 --- a/source/slang/ir-legalize-types.cpp +++ b/source/slang/ir-legalize-types.cpp @@ -210,7 +210,15 @@ struct TypeLegalizationContext // When inserting new globals, put them before this one. IRGlobalValue* insertBeforeGlobal = nullptr; + // When inserting new parameters, put them before this one. + IRParam* insertBeforeParam = nullptr; + Dictionary mapValToLegalVal; + + IRVar* insertBeforeLocalVar = nullptr; + // store local var instructions that have been replaced here, so we can free them + // when legalization has done + List oldLocalVars; }; static void registerLegalizedValue( @@ -353,6 +361,22 @@ static LegalType legalizeType( return LegalType::simple(type); } +// Represents the "chain" of declarations that +// were followed to get to a variable that we +// are now declaring as a leaf variable. +struct LegalVarChain +{ + LegalVarChain* next; + VarLayout* varLayout; +}; + +static LegalVal declareVars( + TypeLegalizationContext* context, + IROp op, + LegalType type, + TypeLayout* typeLayout, + LegalVarChain* varChain); + // Legalize a type, and then expect it to // result in a simple type. static RefPtr legalizeSimpleType( @@ -388,6 +412,42 @@ static LegalVal legalizeOperand( return LegalVal::simple(irValue); } +static void getArgumentValues( + List & instArgs, + LegalVal val) +{ + switch (val.flavor) + { + case LegalVal::Flavor::simple: + instArgs.Add(val.getSimple()); + break; + case LegalVal::Flavor::implicitDeref: + getArgumentValues(instArgs, val.getImplicitDeref()); + break; + case LegalVal::Flavor::tuple: + { + for (auto elem : val.getTuple()->elements) + getArgumentValues(instArgs, elem.val); + } + break; + } +} + +static LegalVal legalizeCall( + TypeLegalizationContext* context, + IRCall* callInst) +{ + // TODO: implement legalization of non-simple return types + auto retType = legalizeType(context, callInst->type); + SLANG_ASSERT(retType.flavor == LegalType::Flavor::simple); + + List instArgs; + for (auto i = 1u; i < callInst->argCount; i++) + getArgumentValues(instArgs, legalizeOperand(context, callInst->getArg(i))); + + return LegalVal::simple(context->builder->emitCallInst(callInst->type, callInst->func.usedValue, instArgs.Count(), instArgs.Buffer())); +} + static LegalVal legalizeLoad( TypeLegalizationContext* context, LegalVal legalPtrVal) @@ -431,6 +491,48 @@ static LegalVal legalizeLoad( } } +static LegalVal legalizeStore( + TypeLegalizationContext* context, + LegalVal legalPtrVal, + LegalVal legalVal) +{ + switch (legalPtrVal.flavor) + { + case LegalVal::Flavor::simple: + { + context->builder->emitStore(legalPtrVal.getSimple(), legalVal.getSimple()); + return legalVal; + } + break; + + case LegalVal::Flavor::implicitDeref: + // TODO: what is the right behavior here? + if (legalVal.flavor == LegalVal::Flavor::implicitDeref) + return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal.getImplicitDeref()); + else + return legalizeStore(context, legalPtrVal.getImplicitDeref(), legalVal); + + case LegalVal::Flavor::tuple: + { + // We need to emit a store for each element of + // the tuple. + auto destTuple = legalPtrVal.getTuple(); + auto valTuple = legalVal.getTuple(); + SLANG_ASSERT(destTuple->elements.Count() == valTuple->elements.Count()); + for (UInt i = 0; i < valTuple->elements.Count(); i++) + { + legalizeStore(context, destTuple->elements[i].val, valTuple->elements[i].val); + } + return legalVal; + } + break; + + default: + SLANG_UNEXPECTED("unhandled case"); + break; + } +} + static LegalVal legalizeFieldAddress( TypeLegalizationContext* context, LegalType type, @@ -492,6 +594,12 @@ static LegalVal legalizeInst( case kIROp_FieldAddress: return legalizeFieldAddress(context, type, args[0], args[1]); + case kIROp_Store: + return legalizeStore(context, args[0], args[1]); + + case kIROp_Call: + return legalizeCall(context, (IRCall*)inst); + default: // TODO: produce a user-visible diagnostic here SLANG_UNEXPECTED("non-simple operand(s)!"); @@ -499,10 +607,74 @@ static LegalVal legalizeInst( } } +RefPtr findVarLayout(IRValue* value) +{ + if (auto layoutDecoration = value->findDecoration()) + return layoutDecoration->layout.As(); + return nullptr; +} + +static LegalVal legalizeLocalVar( + TypeLegalizationContext* context, + IRVar* irLocalVar) +{ + // Legalize the type for the variable's value + auto legalValueType = legalizeType( + context, + irLocalVar->getType()->getValueType()); + + RefPtr varLayout = findVarLayout(irLocalVar); + RefPtr typeLayout = varLayout ? varLayout->typeLayout : nullptr; + + // If we've decided to do implicit deref on the type, + // then go ahead and declare a value of the pointed-to type. + LegalType maybeSimpleType = legalValueType; + while (maybeSimpleType.flavor == LegalType::Flavor::implicitDeref) + { + maybeSimpleType = maybeSimpleType.getImplicitDeref()->valueType; + } + + switch (maybeSimpleType.flavor) + { + case LegalType::Flavor::simple: + // Easy case: the type is usable as-is, and we + // should just do that. + irLocalVar->type = context->session->getPtrType( + maybeSimpleType.getSimple()); + return LegalVal::simple(irLocalVar); + + default: + { + context->insertBeforeLocalVar = irLocalVar; + + LegalVarChain* varChain = nullptr; + LegalVarChain varChainStorage; + if (varLayout) + { + varChainStorage.next = nullptr; + varChainStorage.varLayout = varLayout; + varChain = &varChainStorage; + } + + LegalVal newVal = declareVars(context, kIROp_Var, legalValueType, typeLayout, varChain); + + // Remove the old local var. + irLocalVar->removeFromParent(); + // add old local var to list + context->oldLocalVars.Add(irLocalVar); + return newVal; + } + break; + } +} + static LegalVal legalizeInst( TypeLegalizationContext* context, IRInst* inst) { + if (inst->op == kIROp_Var) + return legalizeLocalVar(context, (IRVar*)inst); + // Need to legalize all the operands. auto argCount = inst->getArgCount(); List legalArgs; @@ -567,40 +739,87 @@ static LegalVal legalizeInst( return legalVal; } +static void addParamType(IRFuncType * ftype, LegalType t) +{ + switch (t.flavor) + { + case LegalType::Flavor::simple: + ftype->paramTypes.Add(t.obj.As()); + break; + case LegalType::Flavor::implicitDeref: + { + auto imp = t.obj.As(); + addParamType(ftype, imp->valueType); + break; + } + case LegalType::Flavor::tuple: + { + auto tup = t.obj.As(); + for (auto & elem : tup->elements) + addParamType(ftype, elem.type); + } + break; + default: + SLANG_ASSERT(false); + } +} + static void legalizeFunc( TypeLegalizationContext* context, IRFunc* irFunc) { // Overwrite the function's type with // the result of legalization. - irFunc->type = legalizeSimpleType(context, irFunc->type); - + auto newFuncType = new IRFuncType(); + newFuncType->setSession(context->session); + auto oldFuncType = irFunc->type.As(); + newFuncType->resultType = legalizeSimpleType(context, oldFuncType->resultType); + for (auto & paramType : oldFuncType->paramTypes) + { + auto legalParamType = legalizeType(context, paramType); + addParamType(newFuncType, legalParamType); + } + irFunc->type = newFuncType; + List paramVals; + List oldParams; + + // we use this list to store replaced local var insts. + // these old instructions will be freed when we are done. + context->oldLocalVars.Clear(); + // Go through the blocks of the function for (auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock()) { // Legalize the parameters of the block, which may // involve increasing the number of parameters - for (auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam()) + for (auto pp = bb->getFirstParam(); pp; pp = pp->nextParam) { auto legalParamType = legalizeType(context, pp->getType()); - - switch (legalParamType.flavor) + if (legalParamType.flavor != LegalType::Flavor::simple) { - case LegalType::Flavor::simple: - // The type is simple, so we can just rewrite it in place - pp->type = legalParamType.getSimple(); - break; + context->insertBeforeParam = pp; + context->builder->curBlock = nullptr; - default: - // We have something like a tuple, and will need - // to expand into multiple parameters now. - SLANG_UNEXPECTED("need to handle it!"); - break; + auto paramVal = declareVars(context, kIROp_Param, legalParamType, nullptr, nullptr); + paramVals.Add(paramVal); + if (pp == bb->getFirstParam()) + { + bb->firstParam = pp; + while (bb->firstParam->prevParam) + bb->firstParam = bb->firstParam->prevParam; + } + bb->lastParam = pp->prevParam; + if (pp->prevParam) + pp->prevParam->nextParam = pp->nextParam; + if (pp->nextParam) + pp->nextParam->prevParam = pp->prevParam; + auto oldParam = pp; + oldParams.Add(oldParam); + registerLegalizedValue(context, oldParam, paramVal); } - + } - // Now legalize the instructions inside the block IRInst* nextInst = nullptr; for (auto ii = bb->getFirstInst(); ii; ii = nextInst) @@ -611,18 +830,17 @@ static void legalizeFunc( registerLegalizedValue(context, ii, legalVal); } + + } + for (auto & op : oldParams) + { + SLANG_ASSERT(op->firstUse == nullptr || op->firstUse->nextUse == nullptr); + op->deallocate(); } + for (auto & lv : context->oldLocalVars) + lv->deallocate(); } -// Represents the "chain" of declarations that -// were followed to get to a variable that we -// are now declaring as a leaf variable. -struct LegalVarChain -{ - LegalVarChain* next; - VarLayout* varLayout; -}; - static LegalVal declareSimpleVar( TypeLegalizationContext* context, IROp op, @@ -696,7 +914,37 @@ static LegalVal declareSimpleVar( return LegalVal::simple(globalVar); } break; + case kIROp_Var: + { + IRBuilder* builder = context->builder; + + auto localVar = builder->emitVar(type); + localVar->removeFromParent(); + localVar->insertBefore(context->insertBeforeLocalVar); + if (varLayout) + { + builder->addLayoutDecoration(localVar, varLayout); + } + return LegalVal::simple(localVar); + } + break; + case kIROp_Param: + { + IRBuilder* builder = context->builder; + auto param = builder->emitParam(type); + if (context->insertBeforeParam->prevParam) + context->insertBeforeParam->prevParam->nextParam = param; + param->prevParam = context->insertBeforeParam->prevParam; + param->nextParam = context->insertBeforeParam; + context->insertBeforeParam->prevParam = param; + if (varLayout) + { + builder->addLayoutDecoration(param, varLayout); + } + return LegalVal::simple(param); + } + break; default: SLANG_UNEXPECTED("unexpected IR opcode"); break; @@ -808,13 +1056,6 @@ static LegalVal declareVars( } } -RefPtr findVarLayout(IRValue* value) -{ - if (auto layoutDecoration = value->findDecoration()) - return layoutDecoration->layout.As(); - return nullptr; -} - static void legalizeGlobalVar( TypeLegalizationContext* context, IRGlobalVar* irGlobalVar) diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index bbbc1812b..759c386cd 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -2850,7 +2850,6 @@ struct DeclLoweringVisitor : DeclVisitor IRType* irParamType = irResultType; paramTypes.Add(irParamType); subBuilder->emitParam(irParamType); - // TODO: we need some way to wire this up to the `newValue` // or whatever name we give for that parameter inside // the setter body. diff --git a/tests/compute/func-param-legalize.slang b/tests/compute/func-param-legalize.slang new file mode 100644 index 000000000..285fcfbb7 --- /dev/null +++ b/tests/compute/func-param-legalize.slang @@ -0,0 +1,35 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:Texture2D(size=4, content = one) : dxbinding(0),glbinding(0) +//TEST_INPUT: Sampler : dxbinding(0),glbinding(0,1,2,3,4,5,6) +//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out + +struct Param +{ + Texture2D tex; + SamplerState samplerState; + float base; +}; + +Texture2D diffuseMap; +SamplerState samplerState; +RWStructuredBuffer outputBuffer; + +float4 run(Param p) +{ + return p.tex.SampleLevel(p.samplerState, float2(0.0), 0) + p.base; +} + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + Param p; + p.tex = diffuseMap; + p.samplerState = samplerState; + p.base = -0.5; + float4 outVal = run(p); + + outputBuffer[0] = outVal.x; + outputBuffer[1] = outVal.y; + outputBuffer[2] = outVal.z; + outputBuffer[3] = outVal.w; +} \ No newline at end of file diff --git a/tests/compute/func-param-legalize.slang.expected.txt b/tests/compute/func-param-legalize.slang.expected.txt new file mode 100644 index 000000000..e4e4c642a --- /dev/null +++ b/tests/compute/func-param-legalize.slang.expected.txt @@ -0,0 +1,4 @@ +3F000000 +3F000000 +3F000000 +3F000000 diff --git a/tests/compute/shaderlib.slang b/tests/compute/shaderlib.slang new file mode 100644 index 000000000..fdcce8552 --- /dev/null +++ b/tests/compute/shaderlib.slang @@ -0,0 +1,196 @@ +//TEST(compute):COMPARE_COMPUTE:-xslang -use-ir +//TEST_INPUT:Texture2D(size=4, content = one) : dxbinding(0),glbinding(0) +//TEST_INPUT: Sampler : dxbinding(0),glbinding(0,1,2,3,4,5,6) +//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out +struct SurfacePosition +{ + float3 position; + float3 normal; + float2 uv; + float3 color; +}; +struct LightSample +{ + float3 direction; + float3 intensity; +}; +interface ILight +{ + int getSampleCount(SurfacePosition pos); + LightSample getSample(SurfacePosition pos, int n, int i); +} + +struct View +{ + float3 cameraPos; + float3 cameraDir; + float2 viewportSize; + float4x4 viewTransform, projectionTransform, worldTransform, + viewProjTransform, invViewTransform, invWorldTransform; + SamplerState samplerState; +}; + +interface IBRDFEvaluator +{ + associatedtype SurfaceProperty; + SurfaceProperty evalSurfaceProperty(SurfacePosition pos, View view); + float3 evalLighting(SurfaceProperty surf, LightSample lightSample); + float getOpacity(); +} + +interface IMaterial +{ + associatedtype BRDFEvaluator : IBRDFEvaluator; + BRDFEvaluator evalPattern(View view, SurfacePosition surf); +} + +struct DirectionalLight : ILight +{ + float3 direction; + float3 intensity; + int getSampleCount(SurfacePosition pos) + { + return 1; + } + LightSample getSample(int n, int i) + { + LightSample ls; + ls.direction = direction; + ls.intensity = intensity; + return ls; + } +}; + +float3 EnvBRDFApprox( float3 SpecularColor, float Roughness, float NoV ) +{ + float4 c0 = float4(-1, -0.0275, -0.572, 0.022); + float4 c1 = float4(1, 0.0425, 1.04, -0.04); + float4 r = Roughness * c0 + c1; + float a004 = min( r.x * r.x, exp2( -9.28 * NoV ) ) * r.x + r.y; + float2 AB = float2( -1.04, 1.04 ) * a004 + r.zw; + AB.y *= min(50.0 * SpecularColor.g, 1.0); + return SpecularColor * AB.x + AB.y; +} + +float PhongApprox(float Roughness, float RoL) +{ + float a = Roughness * Roughness; + a = max(a, 0.008); + float a2 = a * a; + float rcp_a2 = 1.0/(a2); + float c = 0.72134752 * rcp_a2 + 0.39674113; + float p = rcp_a2 * exp2(c * RoL - c); + // Total 7 instr + return min(p, rcp_a2); +} + +struct DisneyBRDFSurfaceProperty +{ + float3 diffuseColor; + float3 fspecularColor; + float3 R; +}; + +struct DisneyBRDFEvaluator : IBRDFEvaluator +{ + float3 baseColor; + float3 normal; + float3 emissive; + float roughness, metallic, specular, opacity; + + typedef DisneyBRDFSurfaceProperty SurfaceProperty; + + DisneyBRDFSurfaceProperty evalSurfaceProperty(SurfacePosition pos, View view) + { + DisneyBRDFSurfaceProperty rs; + float3 viewDir = normalize(pos.position - view.cameraPos); + rs.diffuseColor = baseColor * (1.0-metallic); + float dielectricSpecluar = 0.02 * specular; + float3 specularColor = float3(dielectricSpecluar - dielectricSpecluar * metallic) + + baseColor * metallic; + float NoV = max(dot(normal, viewDir), 0.0); + rs.fspecularColor = EnvBRDFApprox(specularColor, roughness, NoV); + rs.R = reflect(-viewDir, normal); + return rs; + } + + float3 evalLighting(DisneyBRDFSurfaceProperty surf, LightSample lightSample) + { + float RoL = max(0, dot(surf.R, lightSample.direction)); + float dotNL = clamp(dot(normal, lightSample.direction), 0.01, 0.99); + float3 color = lightSample.intensity.xyz * dotNL * + (surf.diffuseColor + surf.fspecularColor * PhongApprox(roughness, RoL)) + emissive; + return color; + } + + float getOpacity() + { + return opacity; + } +}; + +struct DisneyMaterial0 : IMaterial +{ + typedef DisneyBRDFEvaluator BRDFEvaluator; + Texture2D diffuseMap; + DisneyBRDFEvaluator evalPattern(View view, SurfacePosition surf) + { + DisneyBRDFEvaluator rs; + rs.baseColor = diffuseMap.Sample(view.samplerState, surf.uv).xyz; + rs.normal = float3(0.0, 1.0, 0.0); + rs.emissive = float3(0.0); + rs.roughness = 1.0; + rs.metallic = 0.0; + rs.specular = 1.0; + rs.opacity = 0.5; + return rs; + } +}; + +__generic +float4 computeShading(View view, SurfacePosition surfPos, TMaterial mat, TLight light) +{ + TMaterial.BRDFEvaluator brdf = mat.evalPattern(view, surfPos); + int lightSampleCount = light.getSampleCount(surfPos); + float3 color = 0; + TMaterial.BRDFEvaluator.SurfaceProperty surfProp = brdf.evalSurfaceProperty(surfPos, view); + for (int i = 0; i < lightSampleCount; i++) + { + color += brdf.evalLighting(surfProp, light.getSample(surfPos, lightSampleCount, i)); + } + return float4(color, brdf.getOpacity()); +} + +Texture2D diffuseMap; +SamplerState samplerState; +RWStructuredBuffer outputBuffer; + +[numthreads(1, 1, 1)] +void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID) +{ + View view; + view.cameraPos = 0.0; + view.cameraDir = float3(0.0, 0.0, -1.0); + view.viewportSize = float2(1920.0, 1080.0); + view.samplerState = samplerState; + + DisneyMaterial0 material; + material.diffuseMap = diffuseMap; + + DirectionalLight dirLight; + dirLight.direction = float3(0.0, 1.0, 0.0); + dirLight.intensity = float3(1.0, 1.0, 1.0); + + SurfacePosition surfPos; + surfPos.position = float3(0.0, -10.0, -10.0); + surfPos.normal = float3(0.0, 1.0, 0.0); + surfPos.uv = float2(0.5, 0.5); + surfPos.color = 0.0; + + float4 outVal = computeShading(view, surfPos, material, dirLight); + + outputBuffer[0] = outVal.x; + outputBuffer[1] = outVal.y; + outputBuffer[2] = outVal.z; + outputBuffer[3] = outVal.w; +} \ No newline at end of file -- cgit v1.2.3