diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-09-14 15:37:05 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-09-14 15:37:05 -0700 |
| commit | 10b62eecd94be53eca4ac2555af860f864966d76 (patch) | |
| tree | 9a140acfda0e3f0755f2c120870c72d5a8f4b232 /source | |
| parent | 8cdfce564546c03c2c1ce179561591276aeb23a8 (diff) | |
IR: handle control flow constructs (#186)
* IR: handle control flow constructs
This change includes a bunch of fixes and additions to the IR path:
- `slang-ir-assembly` is now a valid output target (so we can use it for testing)
- This uses what used to be the IR "dumping" logic, revamped to support much prettier output.
- A future change will need to add back support for less prettified output to use when actually debugging
- IR generation for `for` loops and `if` statements is supported
- HLSL output from the above control flow constructs is implemented
- Revamped the handling of l-values, and in particular work on compound ops like `+=`
- Add basic IR support for `groupshared` variables
- Add basic IR support for storing compute thread-group size
- Output semantics on entry point parameters
- This uses the AST structures to find semantics, so it still needs work
- Pass through loop unroll flags
- This is required to match `fxc` output, at least until we implement
unrolling ourselves.
* Fixup: 64-bit build issues.
* fixup for merge
Diffstat (limited to 'source')
| -rw-r--r-- | source/core/slang-string.cpp | 14 | ||||
| -rw-r--r-- | source/core/slang-string.h | 7 | ||||
| -rw-r--r-- | source/slang/compiler.cpp | 52 | ||||
| -rw-r--r-- | source/slang/compiler.h | 2 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 438 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 8 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang.cpp | 9 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 82 | ||||
| -rw-r--r-- | source/slang/ir-insts.h | 569 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 644 | ||||
| -rw-r--r-- | source/slang/ir.h | 326 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 1211 | ||||
| -rw-r--r-- | source/slang/options.cpp | 2 | ||||
| -rw-r--r-- | source/slang/profile-defs.h | 2 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj | 1 | ||||
| -rw-r--r-- | source/slang/slang.vcxproj.filters | 7 | ||||
| -rw-r--r-- | source/slang/type-layout.cpp | 2 |
17 files changed, 2767 insertions, 609 deletions
diff --git a/source/core/slang-string.cpp b/source/core/slang-string.cpp index 9bc9e3a54..5a3d8e4f9 100644 --- a/source/core/slang-string.cpp +++ b/source/core/slang-string.cpp @@ -266,7 +266,7 @@ namespace Slang append(slice.begin(), slice.end()); } - void String::append(int value, int radix) + void String::append(int32_t value, int radix) { enum { kCount = 33 }; char* data = prepareForAppend(kCount); @@ -275,7 +275,7 @@ namespace Slang buffer->length += count; } - void String::append(unsigned int value, int radix) + void String::append(uint32_t value, int radix) { enum { kCount = 33 }; char* data = prepareForAppend(kCount); @@ -284,7 +284,7 @@ namespace Slang buffer->length += count; } - void String::append(long long value, int radix) + void String::append(int64_t value, int radix) { enum { kCount = 65 }; char* data = prepareForAppend(kCount); @@ -293,6 +293,14 @@ namespace Slang buffer->length += count; } + void String::append(uint64_t value, int radix) + { + enum { kCount = 65 }; + char* data = prepareForAppend(kCount); + auto count = IntToAscii(data, value, radix); + ReverseInternalAscii(data, count); + buffer->length += count; + } void String::append(float val, const char * format) { diff --git a/source/core/slang-string.h b/source/core/slang-string.h index 98fd51a07..ed333f8e8 100644 --- a/source/core/slang-string.h +++ b/source/core/slang-string.h @@ -257,9 +257,10 @@ namespace Slang return getData() + getLength(); } - void append(int value, int radix = 10); - void append(unsigned int value, int radix = 10); - void append(long long value, int radix = 10); + void append(int32_t value, int radix = 10); + void append(uint32_t value, int radix = 10); + void append(int64_t value, int radix = 10); + void append(uint64_t value, int radix = 10); void append(float val, const char * format = "%g"); void append(double val, const char * format = "%g"); diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp index 9ca1d247f..25ae460c7 100644 --- 
a/source/slang/compiler.cpp +++ b/source/slang/compiler.cpp @@ -499,25 +499,14 @@ namespace Slang } #endif -#if 0 - String emitSPIRVAssembly( - ExtraContext& context) + List<uint8_t> emitSlangIRForEntryPoint( + EntryPointRequest* entryPoint) { - if(context.getTranslationUnitOptions().entryPoints.Count() == 0) - { - // TODO(tfoley): need to write diagnostics into this whole thing... - fprintf(stderr, "no entry point specified\n"); - return ""; - } - - StringBuilder sb; - for (auto entryPoint : context.getTranslationUnitOptions().entryPoints) - { - sb << emitSPIRVAssemblyForEntryPoint(context, entryPoint); - } - return sb.ProduceString(); + SLANG_UNIMPLEMENTED_X("Slang IR Binary Generation"); } -#endif + + String emitSlangIRAssemblyForEntryPoint( + EntryPointRequest* entryPoint); // Do emit logic for a single entry point CompileResult emitEntryPoint( @@ -578,6 +567,22 @@ namespace Slang } break; + case CodeGenTarget::SlangIR: + { + List<uint8_t> code = emitSlangIRForEntryPoint(entryPoint); + maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target); + result = CompileResult(code); + } + break; + + case CodeGenTarget::SlangIRAssembly: + { + String code = emitSlangIRAssemblyForEntryPoint(entryPoint); + maybeDumpIntermediate(compileRequest, code.Buffer(), target); + result = CompileResult(code); + } + break; + case CodeGenTarget::None: // The user requested no output break; @@ -957,6 +962,10 @@ namespace Slang dumpIntermediateText(compileRequest, data, size, ".dxbc.asm"); break; + case CodeGenTarget::SlangIRAssembly: + dumpIntermediateText(compileRequest, data, size, ".slang-ir.asm"); + break; + case CodeGenTarget::SPIRV: dumpIntermediateBinary(compileRequest, data, size, ".spv"); { @@ -972,6 +981,15 @@ namespace Slang dumpIntermediateText(compileRequest, dxbcAssembly.begin(), dxbcAssembly.Length(), ".dxbc.asm"); } break; + + case CodeGenTarget::SlangIR: + dumpIntermediateBinary(compileRequest, data, size, ".slang-ir"); + { + // TODO: need to support 
dissassembly from Slang IR binary +// String slangIRAssembly = dissassembleSlangIR(compileRequest, data, size); +// dumpIntermediateText(compileRequest, slangIRAssembly.begin(), slangIRAssembly.Length(), ".slang-ir.asm"); + } + break; } } diff --git a/source/slang/compiler.h b/source/slang/compiler.h index b6eca1b36..7ec812d63 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -47,6 +47,8 @@ namespace Slang DXBytecode = SLANG_DXBC, DXBytecodeAssembly = SLANG_DXBC_ASM, ReflectionJSON = SLANG_REFLECTION_JSON, + SlangIR = SLANG_IR, + SlangIRAssembly = SLANG_IR_ASM, }; enum class LineDirectiveMode : SlangLineDirectiveMode diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 209436da4..25a0b5323 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -1,6 +1,7 @@ // emit.cpp #include "emit.h" +#include "ir-insts.h" #include "lower.h" #include "lower-to-ir.h" #include "name.h" @@ -3966,11 +3967,16 @@ emitDeclImpl(decl, nullptr); { Simple, Ptr, + Array, }; Flavor flavor; IRDeclaratorInfo* next; - String const* name; + union + { + String const* name; + IRInst* elementCount; + }; }; void emitDeclarator( @@ -3991,6 +3997,13 @@ emitDeclImpl(decl, nullptr); emit("*"); emitDeclarator(context, declarator->next); break; + + case IRDeclaratorInfo::Flavor::Array: + emitDeclarator(context, declarator->next); + emit("["); + emitIROperand(context, declarator->elementCount); + emit("]"); + break; } } @@ -4008,6 +4021,13 @@ emitDeclImpl(decl, nullptr); emit(((IRConstant*) inst)->u.floatVal); break; + case kIROp_boolConst: + { + bool val = ((IRConstant*)inst)->u.intVal != 0; + emit(val ? 
"true" : "false"); + } + break; + default: SLANG_UNIMPLEMENTED_X("val case for emit"); break; @@ -4028,6 +4048,8 @@ emitDeclImpl(decl, nullptr); CASE(Float32Type, float); CASE(Int32Type, int); CASE(UInt32Type, uint); + CASE(VoidType, void); + CASE(BoolType, bool); #undef CASE @@ -4084,6 +4106,15 @@ emitDeclImpl(decl, nullptr); } break; + case kIROp_structuredBufferType: + { + auto tt = (IRBufferType*) type; + emit("StructuredBuffer<"); + emitIRType(context, tt->getElementType(), nullptr); + emit(">"); + } + break; + case kIROp_SamplerType: { // TODO: actually look at the flavor and emit the right name @@ -4127,11 +4158,11 @@ emitDeclImpl(decl, nullptr); IRType* type, IRDeclaratorInfo* declarator) { - switch( type->op ) + switch (type->op) { case kIROp_PtrType: { - auto ptrType = (IRPtrType*) type; + auto ptrType = (IRPtrType*)type; IRDeclaratorInfo ptrDeclarator; ptrDeclarator.flavor = IRDeclaratorInfo::Flavor::Ptr; @@ -4140,6 +4171,18 @@ emitDeclImpl(decl, nullptr); } break; + case kIROp_arrayType: + { + auto arrayType = (IRArrayType*)type; + + IRDeclaratorInfo arrayDeclarator; + arrayDeclarator.flavor = IRDeclaratorInfo::Flavor::Array; + arrayDeclarator.elementCount = arrayType->getElementCount(); + arrayDeclarator.next = declarator; + emitIRType(context, arrayType->getElementType(), &arrayDeclarator); + } + break; + default: emitIRSimpleType(context, type); emitDeclarator(context, declarator); @@ -4178,6 +4221,7 @@ emitDeclImpl(decl, nullptr); case kIROp_IntLit: case kIROp_FloatLit: + case kIROp_boolConst: case kIROp_FieldAddress: case kIROp_getElementPtr: return true; @@ -4280,17 +4324,9 @@ emitDeclImpl(decl, nullptr); switch(inst->op) { case kIROp_IntLit: - { - auto irConst = (IRConstant*) inst; - emit(irConst->u.intVal); - } - break; - case kIROp_FloatLit: - { - auto irConst = (IRConstant*) inst; - emit(irConst->u.floatVal); - } + case kIROp_boolConst: + emitIRSimpleValue(context, inst); break; case kIROp_Construct: @@ -4351,8 +4387,39 @@ emitDeclImpl(decl, 
nullptr); CASE(kIROp_Div, /); CASE(kIROp_Mod, %); + CASE(kIROp_Lsh, <<); + CASE(kIROp_Rsh, >>); + + CASE(kIROp_Eql, ==); + CASE(kIROp_Neq, !=); + CASE(kIROp_Greater, >); + CASE(kIROp_Less, <); + CASE(kIROp_Geq, >=); + CASE(kIROp_Leq, <=); + + CASE(kIROp_BitAnd, &); + CASE(kIROp_BitXor, ^); + CASE(kIROp_BitOr, |); + + CASE(kIROp_And, &&); + CASE(kIROp_Or, ||); + #undef CASE + case kIROp_Not: + { + if (inst->getType()->op == kIROp_BoolType) + { + emit("!"); + } + else + { + emit("~"); + } + emitIROperand(context, inst->getArg(1)); + } + break; + case kIROp_Sample: emitIROperand(context, inst->getArg(1)); emit(".Sample("); @@ -4408,6 +4475,18 @@ emitDeclImpl(decl, nullptr); emit("]"); break; + case kIROp_BufferStore: + emitIROperand(context, inst->getArg(1)); + emit("["); + emitIROperand(context, inst->getArg(2)); + emit("] = "); + emitIROperand(context, inst->getArg(3)); + break; + + case kIROp_GroupMemoryBarrierWithGroupSync: + emit("GroupMemoryBarrierWithGroupSync()"); + break; + case kIROp_getElement: case kIROp_getElementPtr: emitIROperand(context, inst->getArg(1)); @@ -4426,8 +4505,29 @@ emitDeclImpl(decl, nullptr); emit(")"); break; + case kIROp_swizzle: + { + auto ii = (IRSwizzle*)inst; + emitIROperand(context, ii->getBase()); + emit("."); + UInt elementCount = ii->getElementCount(); + for (UInt ee = 0; ee < elementCount; ++ee) + { + IRInst* irElementIndex = ii->getElementIndex(ee); + assert(irElementIndex->op == kIROp_IntLit); + IRConstant* irConst = (IRConstant*)irElementIndex; + + UInt elementIndex = (UInt)irConst->u.intVal; + assert(elementIndex < 4); + + char const* kComponents[] = { "x", "y", "z", "w" }; + emit(kComponents[elementIndex]); + } + } + break; + default: - emit("/* uhandled */"); + emit("/* unhandled */"); break; } } @@ -4478,6 +4578,51 @@ emitDeclImpl(decl, nullptr); emitIROperand(context, ((IRReturnVal*) inst)->getVal()); emit(";\n"); break; + + case kIROp_swizzleSet: + { + auto ii = (IRSwizzleSet*)inst; + emitIRInstResultDecl(context, 
inst); + emitIROperand(context, inst->getArg(1)); + emit(";\n"); + emitIROperand(context, inst); + emit("."); + UInt elementCount = ii->getElementCount(); + for (UInt ee = 0; ee < elementCount; ++ee) + { + IRInst* irElementIndex = ii->getElementIndex(ee); + assert(irElementIndex->op == kIROp_IntLit); + IRConstant* irConst = (IRConstant*)irElementIndex; + + UInt elementIndex = (UInt)irConst->u.intVal; + assert(elementIndex < 4); + + char const* kComponents[] = { "x", "y", "z", "w" }; + emit(kComponents[elementIndex]); + } + emit(" = "); + emitIROperand(context, inst->getArg(2)); + emit(";\n"); + } + break; + +#if 0 + case kIROp_unconditionalBranch: + emit("// unconditionalBranch "); + emitIROperand(context, inst->getArg(1)); + emit("\n"); + break; + + case kIROp_conditionalBranch: + emit("// conditionalBranch "); + emitIROperand(context, inst->getArg(1)); + emit(" "); + emitIROperand(context, inst->getArg(2)); + emit(" "); + emitIROperand(context, inst->getArg(3)); + emit("\n"); + break; +#endif } } @@ -4515,6 +4660,238 @@ emitDeclImpl(decl, nullptr); } } + // We want to emit a range of code in the IR, represented + // by the blocks that are logically in the interval [begin, end) + // which we consider as a single-entry multiple-exit region. + // + // Note: because there are multiple exists, control flow + // may exit this region with operations that do *not* branch + // to `end`, but such non-local control flow will hopefully + // be captured. + // + void emitIRStmtsForBlocks( + EmitContext* context, + IRBlock* begin, + IRBlock* end) + { + IRBlock* block = begin; + while(block != end) + { + // Start by emitting the non-terminator instructions in the block. + auto terminator = block->lastChild; + assert(isTerminatorInst(terminator)); + for (auto inst = block->firstChild; inst != terminator; inst = inst->nextInst) + { + emitIRInst(context, inst); + } + + // Now look at the terminator instruction, which will tell us what we need to emit next. 
+ + switch (terminator->op) + { + default: + SLANG_UNEXPECTED("terminator inst"); + return; + + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + emitIRInst(context, terminator); + return; + + case kIROp_if: + { + // One-sided `if` statement + auto t = (IRIf*)terminator; + + auto trueBlock = t->getTrueBlock(); + auto afterBlock = t->getAfterBlock(); + + emit("if("); + emitIROperand(context, t->getCondition()); + emit(")\n{\n"); + emitIRStmtsForBlocks( + context, + trueBlock, + afterBlock); + emit("}\n"); + + // Continue with the block after the `if` + block = afterBlock; + } + break; + + case kIROp_ifElse: + { + // Two-sided `if` statement + auto t = (IRIfElse*)terminator; + + auto trueBlock = t->getTrueBlock(); + auto falseBlock = t->getFalseBlock(); + auto afterBlock = t->getAfterBlock(); + + emit("if("); + emitIROperand(context, t->getCondition()); + emit(")\n{\n"); + emitIRStmtsForBlocks( + context, + trueBlock, + afterBlock); + emit("}\nelse\n{\n"); + emitIRStmtsForBlocks( + context, + falseBlock, + afterBlock); + emit("}\n"); + + // Continue with the block after the `if` + block = afterBlock; + } + break; + + case kIROp_loop: + { + // Header for a `while` or `for` loop + auto t = (IRLoop*)terminator; + + auto targetBlock = t->getTargetBlock(); + auto breakBlock = t->getBreakBlock(); + auto continueBlock = t->getContinueBlock(); + + if (auto loopControlDecoration = t->findDecoration<IRLoopControlDecoration>()) + { + switch (loopControlDecoration->mode) + { + case kIRLoopControl_Unroll: + emit("[unroll]\n"); + break; + + default: + break; + } + } + + // The challenging case for a loop is when + // there is a `continue` block that we + // need to deal with. 
+ // + if (continueBlock == targetBlock) + { + // There is no continue block, so + // we only need to emit an endless + // loop and then manually `break` + // out of it in the right place(s) + emit("for(;;)\n{\n"); + + emitIRStmtsForBlocks( + context, + targetBlock, + nullptr); + + emit("}\n"); + } + else + { + // Okay, we've got a `continue` block, + // which means we really want to emit + // something akin to: + // + // for(;; <continueBlock>) { <bodyBlock> } + // + // In principle this isn't so bad, since the + // first case is just interVal [`continueBlock`, `targetBlock`) + // and the latter is the interval [`targetBlock`, `continueBlock`). + // + // The challenge of course is that a `for` statement + // only supports *expressions* in the continue part, + // and we might have expanded things into multiple + // instructions (especially if we inlined or desugared anything). + // + // There are a variety of ways we can support lowering this, + // but for now we are going to do something expedient + // that mimics what `fxc` seems to do: + // + // - Output loop body as `for(;;) { <bodyBlock> <continueBlock> }` + // - At any `continue` site, output `{ <continueBlock>; continue; }` + // + // This isn't ideal because it leads to code duplication, but + // it matches what `fxc` does so hopefully it will be the + // best option for our tests. + // + + emit("for(;;)\n{\n"); + + // TODO: Okay, we *said* we'd do this special + // handling of the `continue` sites, but + // we aren't actually setting anything up here... 
+ // + + emitIRStmtsForBlocks( + context, + targetBlock, + nullptr); + + emit("}\n"); + + } + + // Continue with the block after the loop + block = breakBlock; + } + break; + + case kIROp_break: + emit("break;\n"); + return; + + case kIROp_continue: + emit("continue;\n"); + return; + + case kIROp_loopTest: + { + // Loop condition being tested + auto t = (IRLoopTest*)terminator; + + auto afterBlock = t->getTrueBlock(); + + emit("if("); + emitIROperand(context, t->getCondition()); + emit(")\n{} else break;\n"); + + // Continue with the block after the test + block = afterBlock; + } + break; + + case kIROp_unconditionalBranch: + { + // Unconditional branch as part of normal + // control flow. This is either a forward + // edge to the "next" block in an ordinary + // block, or a backward edge to the top + // of a loop. + auto t = (IRUnconditionalBranch*)terminator; + block = t->getTargetBlock(); + } + break; + + case kIROp_conditionalBranch: + SLANG_UNEXPECTED("terminator inst"); + return; + } + + // If we reach this point, then we've emitted + // one block, and we have a new block where + // control flow continues. 
+ // + // We need to handle a special case here, + // when control flow jumps back to the + // starting block of the range we were + // asked to work with: + if (block == begin) return; + } + } + void emitIRFunc( EmitContext* context, IRFunc* func) @@ -4522,6 +4899,19 @@ emitDeclImpl(decl, nullptr); auto funcType = func->getType(); auto resultType = func->getResultType(); + // Deal with decorations that need + // to be emitted as attributes + if (auto threadGroupSizeDecoration = func->findDecoration<IRComputeThreadGroupSizeDecoration>()) + { + emit("[numthreads("); + for (int ii = 0; ii < 3; ++ii) + { + if (ii != 0) emit(", "); + Emit(threadGroupSizeDecoration->sizeAlongAxis[ii]); + } + emit(")]\n"); + } + auto name = getName(func); emitIRType(context, resultType, name); @@ -4535,6 +4925,8 @@ emitDeclImpl(decl, nullptr); auto paramName = getName(pp); emitIRType(context, pp->getType(), paramName); + + emitIRSemantics(context, pp); } emit(")"); @@ -4548,6 +4940,10 @@ emitDeclImpl(decl, nullptr); emit("\n{\n"); // Need to emit the operations in the blocks of the function + + emitIRStmtsForBlocks(context, func->getFirstBlock(), nullptr); + +#if 0 for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() ) { // TODO: need to handle control flow and so forth... 
@@ -4556,6 +4952,7 @@ emitDeclImpl(decl, nullptr); emitIRInst(context, ii); } } +#endif emit("}\n"); } @@ -4687,7 +5084,8 @@ emitDeclImpl(decl, nullptr); IRVar* varDecl) { auto allocatedType = varDecl->getType(); - auto varType = ((IRPtrType*) allocatedType)->getValueType(); + auto varType = allocatedType->getValueType(); + auto addressSpace = allocatedType->getAddressSpace(); switch( varType->op ) { @@ -4706,6 +5104,16 @@ emitDeclImpl(decl, nullptr); emitIRVarModifiers(context, layout); + switch (addressSpace) + { + default: + break; + + case kIRAddressSpace_GroupShared: + emit("groupshared "); + break; + } + emitIRType(context, varType, getName(varDecl)); emitIRSemantics(context, varDecl); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index 7c7d3fda0..0ca838ff5 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -29,7 +29,13 @@ __magic_type(HLSLByteAddressBufferType) struct ByteAddressBuffer __intrinsic_op uint4 Load4(int location, out uint status); }; -__generic<T> __magic_type(HLSLStructuredBufferType) struct StructuredBuffer +__generic<T> +__magic_type(HLSLStructuredBufferType) +__intrinsic_type(${{ + // TODO: we really need a simple way to write an "expression splice" + sb << kIROp_structuredBufferType; +}}) +struct StructuredBuffer { __intrinsic_op void GetDimensions( out uint numStructs, diff --git a/source/slang/hlsl.meta.slang.cpp b/source/slang/hlsl.meta.slang.cpp index a3744f620..49254ac60 100644 --- a/source/slang/hlsl.meta.slang.cpp +++ b/source/slang/hlsl.meta.slang.cpp @@ -29,7 +29,14 @@ sb << " __intrinsic_op uint4 Load4(int location);\n"; sb << " __intrinsic_op uint4 Load4(int location, out uint status);\n"; sb << "};\n"; sb << "\n"; -sb << "__generic<T> __magic_type(HLSLStructuredBufferType) struct StructuredBuffer\n"; +sb << "__generic<T>\n"; +sb << "__magic_type(HLSLStructuredBufferType)\n"; +sb << "__intrinsic_type("; + + // TODO: we really need a simple way to write an "expression splice" 
+ sb << kIROp_structuredBufferType; +sb << ")\n"; +sb << "struct StructuredBuffer\n"; sb << "{\n"; sb << " __intrinsic_op void GetDimensions(\n"; sb << " out uint numStructs,\n"; diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index dab7aa3d7..deaf02d56 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -13,24 +13,29 @@ // Invalid operation: should not appear in valid code INST(Nop, nop, 0, 0) -INST(TypeType, type.type, 0, 0) -INST(VoidType, type.void, 0, 0) -INST(BlockType, type.block, 0, 0) -INST(VectorType, type.vector, 2, 0) -INST(MatrixType, matrixType, 3, 0) -INST(BoolType, type.bool, 0, 0) -INST(Float32Type, type.f32, 0, 0) -INST(Int32Type, type.i32, 0, 0) -INST(UInt32Type, type.u32, 0, 0) -INST(StructType, type.struct, 0, PARENT) -INST(FuncType, func_type, 0, 0) -INST(PtrType, ptr_type, 1, 0) -INST(TextureType, texture_type, 2, 0) -INST(SamplerType, sampler_type, 1, 0) -INST(ConstantBufferType, constant_buffer_type, 1, 0) -INST(TextureBufferType, texture_buffer_type, 1, 0) -INST(readWriteStructuredBufferType, readWriteStructuredBufferType, 1, 0) - +INST(TypeType, Type, 0, 0) +INST(VoidType, Void, 0, 0) +INST(BlockType, Block, 0, 0) +INST(VectorType, Vec, 2, 0) +INST(MatrixType, Mat, 3, 0) +INST(arrayType, Array, 2, 0) + +INST(BoolType, Bool, 0, 0) +INST(Float32Type, Float32, 0, 0) +INST(Int32Type, Int32, 0, 0) +INST(UInt32Type, UInt32, 0, 0) +INST(StructType, Struct, 0, PARENT) +INST(FuncType, Func, 0, 0) +INST(PtrType, Ptr, 1, 0) +INST(TextureType, Texture, 2, 0) +INST(SamplerType, SamplerState, 1, 0) +INST(ConstantBufferType, ConstantBuffer, 1, 0) +INST(TextureBufferType, TextureBuffer, 1, 0) + +INST(structuredBufferType, StructuredBuffer, 1, 0) +INST(readWriteStructuredBufferType, RWStructuredBuffer, 1, 0) + +INST(boolConst, boolConst, 0, 0) INST(IntLit, integer_constant, 0, 0) INST(FloatLit, float_constant, 0, 0) @@ -57,9 +62,48 @@ INST(FieldAddress, get_field_addr, 2, 0) INST(getElement, getElement, 2, 
0) INST(getElementPtr, getElementPtr, 2, 0) +// A swizzle of a vector: +// +// %dst = swizzle %src %idx0 %idx1 ... +// +// where: +// - `src` is a vector<T,N> +// - `dst` is a vector<T,M> +// - `idx0` through `idx[M-1]` are literal integers +// +INST(swizzle, swizzle, 1, 0) + +// Setting a vector via swizzle +// +// %dst = swizzle %base %src %idx0 %idx1 ... +// +// where: +// - `base` is a vector<T,N> +// - `dst` is a vector<T,N> +// - `src` is a vector<T,M> +// - `idx0` through `idx[M-1]` are literal integers +// +// The semantics of the op is: +// +// dst = base; +// for(ii : 0 ... M-1 ) +// dst[ii] = src[idx[ii]]; +// +INST(swizzleSet, swizzleSet, 2, 0) + + INST(ReturnVal, return_val, 1, 0) INST(ReturnVoid, return_void, 1, 0) +INST(unconditionalBranch, unconditionalBranch, 1, 0) +INST(break, break, 1, 0) +INST(continue, continue, 1, 0) +INST(loop, loop, 3, 0) + +INST(conditionalBranch, conditionalBranch, 1, 0) +INST(if, if, 3, 0) +INST(ifElse, ifElse, 4, 0) +INST(loopTest, loopTest, 3, 0) INST(Add, add, 2, 0) INST(Sub, sub, 2, 0) @@ -139,6 +183,8 @@ INST(Sample, sample, 3, 0) INST(SampleGrad, sampleGrad, 4, 0) +INST(GroupMemoryBarrierWithGroupSync, GroupMemoryBarrierWithGroupSync, 0, 0) + PSEUDO_INST(Pos) PSEUDO_INST(PreInc) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h new file mode 100644 index 000000000..846e87bb6 --- /dev/null +++ b/source/slang/ir-insts.h @@ -0,0 +1,569 @@ +// ir-insts.h +#ifndef SLANG_IR_INSTS_H_INCLUDED +#define SLANG_IR_INSTS_H_INCLUDED + +// This file extends the core definitions in `ir.h` +// with a wider variety of concrete instructions, +// and a "builder" abstraction. +// +// TODO: the builder probably needs its own file. + +#include "compiler.h" +#include "ir.h" + +namespace Slang { + +class Decl; + +// Associates an IR-level decoration with a source declaration +// in the high-level AST, that can be used to extract +// additional information that informs code emission. 
+struct IRHighLevelDeclDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_HighLevelDecl }; + + Decl* decl; +}; + +// Associates an IR-level decoration with a source layout +struct IRLayoutDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_Layout }; + + Layout* layout; +}; + +// Identifies a function as an entry point for some stage +struct IREntryPointDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_EntryPoint }; + + Profile profile; +}; + +// Associates a compute-shader entry point function +// with a thread-group size. +struct IRComputeThreadGroupSizeDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_ComputeThreadGroupSize }; + + UInt sizeAlongAxis[3]; +}; + +enum IRLoopControl +{ + kIRLoopControl_Unroll, +}; + +struct IRLoopControlDecoration : IRDecoration +{ + enum { kDecorationOp = kIRDecorationOp_LoopControl }; + + IRLoopControl mode; +}; + +// Types + +struct IRVectorType : IRType +{ + IRUse elementType; + IRUse elementCount; + + IRType* getElementType() { return (IRType*) elementType.usedValue; } + IRInst* getElementCount() { return elementCount.usedValue; } +}; + +struct IRMatrixType : IRType +{ + IRUse elementType; + IRUse rowCount; + IRUse columnCount; + + IRType* getElementType() { return (IRType*) elementType.usedValue; } + IRInst* getRowCount() { return rowCount.usedValue; } + IRInst* getColumnCount() { return columnCount.usedValue; } +}; + +struct IRArrayType : IRType +{ + IRUse elementType; + IRUse elementCount; + + IRType* getElementType() { return (IRType*) elementType.usedValue; } + IRInst* getElementCount() { return elementCount.usedValue; } +}; + + + +struct IRFuncType : IRType +{ + IRUse resultType; + // parameter tyeps are varargs... 
+ + IRType* getResultType() { return (IRType*) resultType.usedValue; } + UInt getParamCount() + { + return getArgCount() - 2; + } + IRType* getParamType(UInt index) + { + return (IRType*) getArg(2 + index); + } +}; + +// Address spaces for IR pointers +enum IRAddressSpace : UInt +{ + // A default address space for things like local variables + kIRAddressSpace_Default, + + // Address space for HLSL `groupshared` allocations + kIRAddressSpace_GroupShared, +}; + +struct IRPtrType : IRType +{ + IRUse valueType; + IRUse addressSpace; + + IRType* getValueType() { return (IRType*) valueType.usedValue; } + + IRAddressSpace getAddressSpace() + { + return IRAddressSpace( + ((IRConstant*)addressSpace.usedValue)->u.intVal); + } +}; + +struct IRTextureType : IRType +{ + IRUse flavor; + IRUse elementType; + + IRIntegerValue getFlavor() { return ((IRConstant*) flavor.usedValue)->u.intVal; } + IRType* getElementType() { return (IRType*) elementType.usedValue; } +}; + +struct IRBufferType : IRType +{ + IRUse elementType; + IRType* getElementType() { return (IRType*) elementType.usedValue; } +}; + + +struct IRUniformBufferType : IRType +{ + IRUse elementType; + IRType* getElementType() { return (IRType*) elementType.usedValue; } +}; + +struct IRConstantBufferType : IRUniformBufferType {}; +struct IRTextureBufferType : IRUniformBufferType {}; + +struct IRCall : IRInst +{ + IRUse func; +}; + +struct IRLoad : IRInst +{ + IRUse ptr; +}; + +struct IRStore : IRInst +{ + IRUse ptr; + IRUse val; +}; + +struct IRStructField; +struct IRFieldExtract : IRInst +{ + IRUse base; + IRUse field; + + IRInst* getBase() { return base.usedValue; } + IRStructField* getField() { return (IRStructField*) field.usedValue; } +}; + +struct IRFieldAddress : IRInst +{ + IRUse base; + IRUse field; + + IRInst* getBase() { return base.usedValue; } + IRStructField* getField() { return (IRStructField*) field.usedValue; } +}; + +// Terminators + +struct IRReturn : IRTerminatorInst +{}; + +struct IRReturnVal : IRReturn 
+{ + IRUse val; + + IRInst* getVal() { return val.usedValue; } +}; + +struct IRReturnVoid : IRReturn +{}; + +struct IRBlock; + +struct IRUnconditionalBranch : IRTerminatorInst +{ + IRUse block; + + IRBlock* getTargetBlock() { return (IRBlock*)block.usedValue; } +}; + +// Special cases of unconditional branch, to handle +// structured control flow: +struct IRBreak : IRUnconditionalBranch {}; +struct IRContinue : IRUnconditionalBranch {}; + +// The start of a loop is a special control-flow +// instruction, that records relevant information +// about the loop structure: +struct IRLoop : IRUnconditionalBranch +{ + // The next block after the loop, which + // is where we expect control flow to + // re-converge, and also where a + // `break` will target. + IRUse breakBlock; + + // The block where control flow will go + // on a `continue`. + IRUse continueBlock; + + IRBlock* getBreakBlock() { return (IRBlock*)breakBlock.usedValue; } + IRBlock* getContinueBlock() { return (IRBlock*)continueBlock.usedValue; } +}; + +struct IRConditionalBranch : IRTerminatorInst +{ + IRUse condition; + IRUse trueBlock; + IRUse falseBlock; + + IRInst* getCondition() { return condition.usedValue; } + IRBlock* getTrueBlock() { return (IRBlock*)trueBlock.usedValue; } + IRBlock* getFalseBlock() { return (IRBlock*)falseBlock.usedValue; } +}; + +// A conditional branch that represent the test inside a loop +struct IRLoopTest : IRConditionalBranch +{ +}; + +// A conditional branch that represents a one-sided `if`: +// +// if( <condition> ) { <trueBlock> } +// <falseBlock> +struct IRIf : IRConditionalBranch +{ + IRBlock* getAfterBlock() { return getFalseBlock(); } +}; + +// A conditional branch that represents a two-sided `if`: +// +// if( <condition> ) { <trueBlock> } +// else { <falseBlock> } +// <afterBlock> +// +struct IRIfElse : IRConditionalBranch +{ + IRUse afterBlock; + + IRBlock* getAfterBlock() { return (IRBlock*)afterBlock.usedValue; } +}; + +struct IRSwizzle : IRReturn +{ + IRUse base; + 
+ IRInst* getBase() { return base.usedValue; } + UInt getElementCount() + { + return getArgCount() - 2; + } + IRInst* getElementIndex(UInt index) + { + return getArg(index + 2); + } +}; + +struct IRSwizzleSet : IRReturn +{ + IRUse base; + IRUse source; + + IRInst* getBase() { return base.usedValue; } + IRInst* getSource() { return source.usedValue; } + UInt getElementCount() + { + return getArgCount() - 3; + } + IRInst* getElementIndex(UInt index) + { + return getArg(index + 3); + } +}; + +// "Parent" Instructions (Declarations) + +struct IRStructField : IRInst +{ + IRType* getFieldType() { return (IRType*) type.usedValue; } + + IRStructField* getNextField() { return (IRStructField*) nextInst; } +}; + +struct IRStructDecl : IRParentInst +{ + IRStructField* getFirstField() { return (IRStructField*) firstChild; } + IRStructField* getLastField() { return (IRStructField*) lastChild; } +}; + + +struct IRVar : IRInst +{ + IRPtrType* getType() + { + return (IRPtrType*)type.usedValue; + } +}; + + +// Description of an instruction to be used for global value numbering +struct IRInstKey +{ + IRInst* inst; + + int GetHashCode(); +}; + +bool operator==(IRInstKey const& left, IRInstKey const& right); + +struct IRConstantKey +{ + IRConstant* inst; + + int GetHashCode(); +}; +bool operator==(IRConstantKey const& left, IRConstantKey const& right); + +struct SharedIRBuilder +{ + // The module that will own all of the IR + IRModule* module; + + Dictionary<IRInstKey, IRInst*> globalValueNumberingMap; + Dictionary<IRConstantKey, IRConstant*> constantMap; +}; + +struct IRBuilder +{ + // Shared state for all IR builders working on the same module + SharedIRBuilder* shared; + + IRModule* getModule() { return shared->module; } + + // The parent instruction to add children to. 
+ IRParentInst* parentInst; + + void addInst(IRParentInst* parent, IRInst* inst); + void addInst(IRInst* inst); + + IRType* getBaseType(BaseType flavor); + IRType* getBoolType(); + IRType* getVectorType(IRType* elementType, IRValue* elementCount); + IRType* getMatrixType( + IRType* elementType, + IRValue* rowCount, + IRValue* columnCount); + IRType* getArrayType(IRType* elementType, IRValue* elementCount); + IRType* getArrayType(IRType* elementType); + + IRType* getTypeType(); + IRType* getVoidType(); + IRType* getBlockType(); + + IRType* getIntrinsicType( + IROp op, + UInt argCount, + IRValue* const* args); + + IRStructDecl* createStructType(); + IRStructField* createStructField(IRType* fieldType); + + IRType* getFuncType( + UInt paramCount, + IRType* const* paramTypes, + IRType* resultType); + + IRType* getPtrType( + IRType* valueType, + IRAddressSpace addressSpace); + + IRType* getPtrType( + IRType* valueType); + + IRValue* getBoolValue(bool value); + IRValue* getIntValue(IRType* type, IRIntegerValue value); + IRValue* getFloatValue(IRType* type, IRFloatingPointValue value); + + IRInst* emitCallInst( + IRType* type, + IRValue* func, + UInt argCount, + IRValue* const* args); + + IRInst* emitIntrinsicInst( + IRType* type, + IROp op, + UInt argCount, + IRValue* const* args); + + IRInst* emitConstructorInst( + IRType* type, + UInt argCount, + IRValue* const* args); + + IRModule* createModule(); + + IRFunc* createFunc(); + + IRBlock* createBlock(); + IRBlock* emitBlock(); + + IRParam* emitParam( + IRType* type); + + IRVar* emitVar( + IRType* type, + IRAddressSpace addressSpace); + + IRVar* emitVar( + IRType* type); + + IRInst* emitLoad( + IRValue* ptr); + + IRInst* emitStore( + IRValue* dstPtr, + IRValue* srcVal); + + IRInst* emitFieldExtract( + IRType* type, + IRValue* base, + IRStructField* field); + + IRInst* emitFieldAddress( + IRType* type, + IRValue* basePtr, + IRStructField* field); + + IRInst* emitElementExtract( + IRType* type, + IRValue* base, + IRValue* 
index); + + IRInst* emitElementAddress( + IRType* type, + IRValue* basePtr, + IRValue* index); + + IRInst* emitSwizzle( + IRType* type, + IRValue* base, + UInt elementCount, + IRValue* const* elementIndices); + + IRInst* emitSwizzle( + IRType* type, + IRValue* base, + UInt elementCount, + UInt const* elementIndices); + + IRInst* emitSwizzleSet( + IRType* type, + IRValue* base, + IRValue* source, + UInt elementCount, + IRValue* const* elementIndices); + + IRInst* emitSwizzleSet( + IRType* type, + IRValue* base, + IRValue* source, + UInt elementCount, + UInt const* elementIndices); + + IRInst* emitReturn( + IRValue* val); + + IRInst* emitReturn(); + + IRInst* emitBranch( + IRBlock* block); + + IRInst* emitBreak( + IRBlock* target); + + IRInst* emitContinue( + IRBlock* target); + + IRInst* emitLoop( + IRBlock* target, + IRBlock* breakBlock, + IRBlock* continueBlock); + + IRInst* emitBranch( + IRValue* val, + IRBlock* trueBlock, + IRBlock* falseBlock); + + IRInst* emitIf( + IRValue* val, + IRBlock* trueBlock, + IRBlock* afterBlock); + + IRInst* emitIfElse( + IRValue* val, + IRBlock* trueBlock, + IRBlock* falseBlock, + IRBlock* afterBlock); + + IRInst* emitLoopTest( + IRValue* val, + IRBlock* bodyBlock, + IRBlock* breakBlock); + + IRDecoration* addDecorationImpl( + IRInst* inst, + UInt decorationSize, + IRDecorationOp op); + + template<typename T> + T* addDecoration(IRInst* inst, IRDecorationOp op) + { + return (T*) addDecorationImpl(inst, sizeof(T), op); + } + + template<typename T> + T* addDecoration(IRInst* inst) + { + return (T*) addDecorationImpl(inst, sizeof(T), IRDecorationOp(T::kDecorationOp)); + } + + IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRInst* inst, Decl* decl); + IRLayoutDecoration* addLayoutDecoration(IRInst* inst, Layout* layout); +}; + +} + +#endif diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 3a6410125..c48880cab 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -1,5 +1,6 @@ // ir.cpp #include "ir.h" +#include 
"ir-insts.h" #include "../core/basic.h" @@ -65,7 +66,12 @@ namespace Slang return nullptr; } - // + // IRFunc + + IRType* IRFunc::getResultType() { return getType()->getResultType(); } + UInt IRFunc::getParamCount() { return getType()->getParamCount(); } + IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); } + IRParam* IRFunc::getFirstParam() { @@ -81,6 +87,8 @@ namespace Slang return (IRParam*) firstInst; } + // IRParam + IRParam* IRParam::getNextParam() { auto next = nextInst; @@ -94,6 +102,36 @@ namespace Slang // + bool isTerminatorInst(IROp op) + { + switch (op) + { + default: + return false; + + case kIROp_ReturnVal: + case kIROp_ReturnVoid: + case kIROp_unconditionalBranch: + case kIROp_conditionalBranch: + case kIROp_break: + case kIROp_continue: + case kIROp_loop: + case kIROp_if: + case kIROp_ifElse: + case kIROp_loopTest: + return true; + } + } + + bool isTerminatorInst(IRInst* inst) + { + if (!inst) return false; + return isTerminatorInst(inst->op); + } + + + // + // Add an instruction to a specific parent void IRBuilder::addInst(IRParentInst* parent, IRInst* inst) { @@ -306,6 +344,28 @@ namespace Slang varArgs); } + template<typename T> + static T* createInstWithTrailingArgs( + IRBuilder* builder, + IROp op, + IRType* type, + IRValue* arg1, + UInt varArgCount, + IRValue* const* varArgs) + { + IRValue* fixedArgs[] = { arg1 }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + return (T*)createInstImpl( + builder, + sizeof(T) + varArgCount * sizeof(IRUse), + op, + type, + fixedArgCount, + fixedArgs, + varArgCount, + varArgs); + } // bool operator==(IRInstKey const& left, IRInstKey const& right) @@ -637,6 +697,8 @@ namespace Slang { switch( flavor ) { + case BaseType::Void: return getVoidType(); + case BaseType::Bool: return getBaseTypeImpl(this, kIROp_BoolType); case BaseType::Float: return getBaseTypeImpl(this, kIROp_Float32Type); case BaseType::Int: return getBaseTypeImpl(this, kIROp_Int32Type); @@ 
-677,6 +739,34 @@ namespace Slang columnCount); } + IRType* IRBuilder::getArrayType(IRType* elementType, IRValue* elementCount) + { + // The client requests an unsized array by passing `nullptr` for + // the `elementCount`. + // + // We currently encode an unsized array as an ordinary array with + // zero elements. TODO: carefully consider this choice. + if (!elementCount) + { + elementCount = getIntValue( + getBaseType(BaseType::Int), + 0); + } + + return findOrEmitInst<IRArrayType>( + this, + kIROp_arrayType, + getTypeType(), + elementType, + elementCount); + } + + IRType* IRBuilder::getArrayType(IRType* elementType) + { + return getArrayType(elementType, nullptr); + } + + IRType* IRBuilder::getTypeType() { return findOrEmitInst<IRType>( @@ -753,21 +843,38 @@ namespace Slang } IRType* IRBuilder::getPtrType( - IRType* valueType) + IRType* valueType, + IRAddressSpace addressSpace) { + auto uintType = getBaseType(BaseType::UInt); + auto irAddressSpace = getIntValue(uintType, addressSpace); + auto inst = findOrEmitInst<IRPtrType>( this, kIROp_PtrType, getTypeType(), - 1, - (IRValue* const*) &valueType); + valueType, + irAddressSpace); return inst; } - IRValue* IRBuilder::getBoolValue(bool value) + IRType* IRBuilder::getPtrType( + IRType* valueType) { - SLANG_UNIMPLEMENTED_X("IR"); + return getPtrType(valueType, kIRAddressSpace_Default); + } + + + IRValue* IRBuilder::getBoolValue(bool inValue) + { + IRIntegerValue value = inValue; + return findOrEmitConstant( + this, + kIROp_boolConst, + getBoolType(), + sizeof(value), + &value); } IRValue* IRBuilder::getIntValue(IRType* type, IRIntegerValue value) @@ -883,9 +990,10 @@ namespace Slang } IRVar* IRBuilder::emitVar( - IRType* type) + IRType* type, + IRAddressSpace addressSpace) { - auto allocatedType = getPtrType(type); + auto allocatedType = getPtrType(type, addressSpace); auto inst = createInst<IRVar>( this, kIROp_Var, @@ -894,6 +1002,13 @@ namespace Slang return inst; } + + IRVar* IRBuilder::emitVar( + IRType* type) + 
{ + return emitVar(type, kIRAddressSpace_Default); + } + IRInst* IRBuilder::emitLoad( IRValue* ptr) { @@ -996,6 +1111,83 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitSwizzle( + IRType* type, + IRValue* base, + UInt elementCount, + IRValue* const* elementIndices) + { + auto inst = createInstWithTrailingArgs<IRSwizzle>( + this, + kIROp_swizzle, + type, + base, + elementCount, + elementIndices); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitSwizzle( + IRType* type, + IRValue* base, + UInt elementCount, + UInt const* elementIndices) + { + auto intType = getBaseType(BaseType::Int); + + IRValue* irElementIndices[4]; + for (UInt ii = 0; ii < elementCount; ++ii) + { + irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); + } + + return emitSwizzle(type, base, elementCount, irElementIndices); + } + + + IRInst* IRBuilder::emitSwizzleSet( + IRType* type, + IRValue* base, + IRValue* source, + UInt elementCount, + IRValue* const* elementIndices) + { + IRValue* fixedArgs[] = { base, source }; + UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]); + + auto inst = createInstWithTrailingArgs<IRSwizzleSet>( + this, + kIROp_swizzleSet, + type, + fixedArgCount, + fixedArgs, + elementCount, + elementIndices); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitSwizzleSet( + IRType* type, + IRValue* base, + IRValue* source, + UInt elementCount, + UInt const* elementIndices) + { + auto intType = getBaseType(BaseType::Int); + + IRValue* irElementIndices[4]; + for (UInt ii = 0; ii < elementCount; ++ii) + { + irElementIndices[ii] = getIntValue(intType, elementIndices[ii]); + } + + return emitSwizzleSet(type, base, source, elementCount, irElementIndices); + } + IRInst* IRBuilder::emitReturn( IRValue* val) { @@ -1018,6 +1210,133 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitBranch( + IRBlock* block) + { + auto inst = createInst<IRUnconditionalBranch>( + this, + kIROp_unconditionalBranch, + getVoidType(), + block); 
+ addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBreak( + IRBlock* target) + { + auto inst = createInst<IRBreak>( + this, + kIROp_break, + getVoidType(), + target); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitContinue( + IRBlock* target) + { + auto inst = createInst<IRContinue>( + this, + kIROp_continue, + getVoidType(), + target); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitLoop( + IRBlock* target, + IRBlock* breakBlock, + IRBlock* continueBlock) + { + IRInst* args[] = { target, breakBlock, continueBlock }; + UInt argCount = sizeof(args) / sizeof(args[0]); + + auto inst = createInst<IRLoop>( + this, + kIROp_loop, + getVoidType(), + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitBranch( + IRValue* val, + IRBlock* trueBlock, + IRBlock* falseBlock) + { + IRInst* args[] = { val, trueBlock, falseBlock }; + UInt argCount = sizeof(args) / sizeof(args[0]); + + auto inst = createInst<IRConditionalBranch>( + this, + kIROp_conditionalBranch, + getVoidType(), + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitIf( + IRValue* val, + IRBlock* trueBlock, + IRBlock* afterBlock) + { + IRInst* args[] = { val, trueBlock, afterBlock }; + UInt argCount = sizeof(args) / sizeof(args[0]); + + auto inst = createInst<IRIf>( + this, + kIROp_if, + getVoidType(), + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitIfElse( + IRValue* val, + IRBlock* trueBlock, + IRBlock* falseBlock, + IRBlock* afterBlock) + { + IRInst* args[] = { val, trueBlock, falseBlock, afterBlock }; + UInt argCount = sizeof(args) / sizeof(args[0]); + + auto inst = createInst<IRIfElse>( + this, + kIROp_ifElse, + getVoidType(), + argCount, + args); + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitLoopTest( + IRValue* val, + IRBlock* bodyBlock, + IRBlock* breakBlock) + { + IRInst* args[] = { val, bodyBlock, breakBlock }; + UInt argCount = sizeof(args) / 
sizeof(args[0]); + + auto inst = createInst<IRLoopTest>( + this, + kIROp_loopTest, + getVoidType(), + argCount, + args); + addInst(inst); + return inst; + } + IRDecoration* IRBuilder::addDecorationImpl( IRInst* inst, UInt decorationSize, @@ -1053,7 +1372,7 @@ namespace Slang struct IRDumpContext { - FILE* file; + StringBuilder* builder; int indent; }; @@ -1061,14 +1380,16 @@ namespace Slang IRDumpContext* context, char const* text) { - fprintf(context->file, "%s", text); + context->builder->append(text); } static void dump( IRDumpContext* context, UInt val) { - fprintf(context->file, "%llu", (unsigned long long)val); + context->builder->append(val); + +// fprintf(context->file, "%llu", (unsigned long long)val); } static void dumpIndent( @@ -1076,7 +1397,7 @@ namespace Slang { for (int ii = 0; ii < context->indent; ++ii) { - dump(context, " "); + dump(context, "\t"); } } @@ -1088,79 +1409,318 @@ namespace Slang { dump(context, "<null>"); } - else + else if(inst->id) { dump(context, "%"); dump(context, inst->id); } + else + { + dump(context, "_"); + } } - static void dumpInst( + static void dumpType( + IRDumpContext* context, + IRType* type); + + static void dumpOperand( IRDumpContext* context, IRInst* inst) { - dumpIndent(context); if (!inst) { - dump(context, "<null>"); + dump(context, "undef"); + return; } - // TODO: need to display a name for the result... + switch (inst->op) + { + case kIROp_IntLit: + dump(context, ((IRConstant*)inst)->u.intVal); + return; - auto op = inst->op; - auto opInfo = &kIROpInfos[op]; + case kIROp_FloatLit: + dump(context, ((IRConstant*)inst)->u.floatVal); + return; - if (inst->id) + case kIROp_boolConst: + dump(context, ((IRConstant*)inst)->u.intVal ? 
"true" : "false"); + return; + + case kIROp_TypeType: + dumpType(context, (IRType*)inst); + return; + + default: + break; + } + + auto type = inst->getType(); + if (type) { - dumpID(context, inst); - dump(context, " = "); + switch (type->op) + { + case kIROp_TypeType: + dumpType(context, (IRType*)inst); + return; + + default: + break; + } } - dump(context, opInfo->name); - // TODO: dump operands - uint32_t argCount = inst->argCount; - for (uint32_t ii = 0; ii < argCount; ++ii) + dumpID(context, inst); + } + + static void dumpType( + IRDumpContext* context, + IRType* type) + { + if (!type) { - if (ii != 0) - dump(context, ", "); - else + dumpID(context, type); + return; + } + + auto op = type->op; + auto opInfo = kIROpInfos[op]; + + switch (op) + { + case kIROp_StructType: + dumpID(context, type); + break; + + default: { - dump(context, " "); + dump(context, opInfo.name); + UInt argCount = type->getArgCount(); + + if (argCount > 1) + { + dump(context, "<"); + for (UInt aa = 1; aa < argCount; ++aa) + { + if (aa != 1) dump(context, ","); + dumpOperand(context, type->getArg(aa)); + + } + dump(context, ">"); + } } + break; + } + } - auto argVal = inst->getArgs()[ii].usedValue; + static void dumpInstTypeClause( + IRDumpContext* context, + IRType* type) + { + dump(context, "\t: "); + dumpType(context, type); - // TODO: actually print the damn operand... 
+ } - dumpID(context, argVal); - } + static void dumpInst( + IRDumpContext* context, + IRInst* inst); - dump(context, "\n"); + static void dumpChildrenRaw( + IRDumpContext* context, + IRParentInst* parent) + { + for (auto ii = parent->firstChild; ii; ii = ii->nextInst) + { + dumpInst(context, ii); + } + } + static void dumpChildren( + IRDumpContext* context, + IRInst* inst) + { + auto op = inst->op; + auto opInfo = &kIROpInfos[op]; if (opInfo->flags & kIROpFlag_Parent) { dumpIndent(context); dump(context, "{\n"); context->indent++; auto parent = (IRParentInst*)inst; - for (auto ii = parent->firstChild; ii; ii = ii->nextInst) - { - dumpInst(context, ii); - } + dumpChildrenRaw(context, parent); context->indent--; dumpIndent(context); dump(context, "}\n"); } } + static void dumpInst( + IRDumpContext* context, + IRInst* inst) + { + if (!inst) + { + dumpIndent(context); + dump(context, "<null>"); + return; + } + + auto op = inst->op; + + // There are several ops we want to special-case here, + // so that they will be more pleasant to look at. 
+ // + switch (op) + { + case kIROp_Module: + dumpIndent(context); + dump(context, "module\n"); + dumpChildren(context, inst); + return; + + case kIROp_Func: + { + IRFunc* func = (IRFunc*)inst; + dump(context, "\n"); + dumpIndent(context); + dump(context, "func "); + dumpID(context, func); + dump(context, "(\n"); + context->indent++; + for (auto pp = func->getFirstParam(); pp; pp = pp->getNextParam()) + { + if (pp != func->getFirstParam()) + dump(context, ",\n"); + + dumpIndent(context); + dump(context, "param "); + dumpID(context, pp); + dumpInstTypeClause(context, pp->getType()); + } + context->indent--; + dump(context, ")\n"); + + dumpIndent(context); + dump(context, "{\n"); + context->indent++; + + for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + if (bb != func->getFirstBlock()) + dump(context, "\n"); + dumpInst(context, bb); + } + + context->indent--; + dump(context, "}\n"); + } + return; + + case kIROp_TypeType: + case kIROp_Param: + case kIROp_IntLit: + case kIROp_FloatLit: + case kIROp_boolConst: + // Don't dump here + return; + + case kIROp_Block: + { + IRBlock* block = (IRBlock*)inst; + + context->indent--; + dump(context, "block "); + dumpID(context, block); + dump(context, ":\n"); + context->indent++; + + dumpChildrenRaw(context, block); + } + return; + + default: + break; + } - void dumpIR(IRModule* module) + // We also want to special-case based on the *type* + // of the instruction + auto type = inst->getType(); + if (type && type->op == kIROp_TypeType) + { + // We probably don't want to print most types + // when producing "friendly" output. 
+ switch (type->op) + { + case kIROp_StructType: + break; + + default: + return; + } + } + + + // Okay, we have a seemingly "ordinary" op now + dumpIndent(context); + + auto opInfo = &kIROpInfos[op]; + + if (type && type->op == kIROp_TypeType) + { + dump(context, "type "); + dumpID(context, inst); + dump(context, "\t= "); + } + else if (type && type->op == kIROp_VoidType) + { + } + else + { + dump(context, "let "); + dumpID(context, inst); + dumpInstTypeClause(context, type); + dump(context, "\t= "); + } + + + dump(context, opInfo->name); + + uint32_t argCount = inst->argCount; + dump(context, "("); + for (uint32_t ii = 1; ii < argCount; ++ii) + { + if (ii != 1) + dump(context, ", "); + + auto argVal = inst->getArgs()[ii].usedValue; + + dumpOperand(context, argVal); + } + dump(context, ")"); + + dump(context, "\n"); + + // The instruction might have children, + // so we need to handle those here + dumpChildren(context, inst); + } + + void printSlangIRAssembly(StringBuilder& builder, IRModule* module) { IRDumpContext context; - context.file = stderr; + context.builder = &builder; context.indent = 0; - dumpInst(&context, module); + dumpChildrenRaw(&context, module); } + String getSlangIRAssembly(IRModule* module) + { + StringBuilder sb; + printSlangIRAssembly(sb, module); + return sb; + } + + } diff --git a/source/slang/ir.h b/source/slang/ir.h index 6e4a25fe1..9b271b09b 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -9,7 +9,6 @@ #include "../core/basic.h" - namespace Slang { // TODO(tfoley): We should ditch this enumeration @@ -143,6 +142,9 @@ enum IRDecorationOp : uint16_t { kIRDecorationOp_HighLevelDecl, kIRDecorationOp_Layout, + kIRDecorationOp_EntryPoint, + kIRDecorationOp_ComputeThreadGroupSize, + kIRDecorationOp_LoopControl, }; // A "decoration" that gets applied to an instruction. 
@@ -220,27 +222,6 @@ struct IRInst // All existing uses of `IRValue` should move to `IRInst` typedef IRInst IRValue; -class Decl; - -// Associates an IR-level decoration with a source declaration -// in the high-level AST, that can be used to extract -// additional information that informs code emission. -struct IRHighLevelDeclDecoration : IRDecoration -{ - enum { kDecorationOp = kIRDecorationOp_HighLevelDecl }; - - Decl* decl; -}; - -// Associates an IR-level decoration with a source layout -struct IRLayoutDecoration : IRDecoration -{ - enum { kDecorationOp = kIRDecorationOp_Layout }; - - Layout* layout; -}; - - typedef long long IRIntegerValue; typedef double IRFloatingPointValue; @@ -266,126 +247,13 @@ struct IRType : IRInst { }; -struct IRVectorType : IRType -{ - IRUse elementType; - IRUse elementCount; - - IRType* getElementType() { return (IRType*) elementType.usedValue; } - IRInst* getElementCount() { return elementCount.usedValue; } -}; - -struct IRMatrixType : IRType -{ - IRUse elementType; - IRUse rowCount; - IRUse columnCount; - - IRType* getElementType() { return (IRType*) elementType.usedValue; } - IRInst* getRowCount() { return rowCount.usedValue; } - IRInst* getColumnCount() { return columnCount.usedValue; } -}; - -struct IRFuncType : IRType -{ - IRUse resultType; - // parameter tyeps are varargs... 
- - IRType* getResultType() { return (IRType*) resultType.usedValue; } - UInt getParamCount() - { - return getArgCount() - 2; - } - IRType* getParamType(UInt index) - { - return (IRType*) getArg(2 + index); - } -}; - -struct IRPtrType : IRType -{ - IRUse valueType; - - IRType* getValueType() { return (IRType*) valueType.usedValue; } -}; - -struct IRTextureType : IRType -{ - IRUse flavor; - IRUse elementType; - - IRIntegerValue getFlavor() { return ((IRConstant*) flavor.usedValue)->u.intVal; } - IRType* getElementType() { return (IRType*) elementType.usedValue; } -}; - -struct IRBufferType : IRType -{ - IRUse elementType; - IRType* getElementType() { return (IRType*) elementType.usedValue; } -}; - - -struct IRUniformBufferType : IRType -{ - IRUse elementType; - IRType* getElementType() { return (IRType*) elementType.usedValue; } -}; - -struct IRConstantBufferType : IRUniformBufferType {}; -struct IRTextureBufferType : IRUniformBufferType {}; - -struct IRCall : IRInst -{ - IRUse func; -}; - -struct IRLoad : IRInst -{ - IRUse ptr; -}; - -struct IRStore : IRInst -{ - IRUse ptr; - IRUse val; -}; - -struct IRStructField; -struct IRFieldExtract : IRInst -{ - IRUse base; - IRUse field; - - IRInst* getBase() { return base.usedValue; } - IRStructField* getField() { return (IRStructField*) field.usedValue; } -}; - -struct IRFieldAddress : IRInst -{ - IRUse base; - IRUse field; - - IRInst* getBase() { return base.usedValue; } - IRStructField* getField() { return (IRStructField*) field.usedValue; } -}; - - // A instruction that ends a basic block (usually because of control flow) struct IRTerminatorInst : IRInst {}; -struct IRReturn : IRTerminatorInst -{}; - -struct IRReturnVal : IRReturn -{ - IRUse val; - - IRInst* getVal() { return val.usedValue; } -}; +bool isTerminatorInst(IROp op); +bool isTerminatorInst(IRInst* inst); -struct IRReturnVoid : IRReturn -{}; // A parent instruction contains a sequence of other instructions // @@ -398,20 +266,6 @@ struct IRParentInst : IRInst 
IRInst* lastChild; }; -struct IRStructField : IRInst -{ - IRType* getFieldType() { return (IRType*) type.usedValue; } - - IRStructField* getNextField() { return (IRStructField*) nextInst; } -}; - -struct IRStructDecl : IRParentInst -{ - IRStructField* getFirstField() { return (IRStructField*) firstChild; } - IRStructField* getLastField() { return (IRStructField*) lastChild; } -}; - - // A basic block is a parent instruction that adds the constraint // that all the children need to be "ordinary" instructions (so // no function declarations, or nested blocks). We also expect @@ -425,6 +279,8 @@ struct IRBlock : IRParentInst IRBlock* getPrevBlock() { return (IRBlock*) prevInst; } IRBlock* getNextBlock() { return (IRBlock*) nextInst; } + + IRFunc* getParent() { return (IRFunc*)parent; } }; // A function parameter is represented by an instruction @@ -434,8 +290,7 @@ struct IRParam : IRInst IRParam* getNextParam(); }; -struct IRVar : IRInst -{}; +struct IRFuncType; // A function is a parent to zero or more blocks of instructions. 
// @@ -445,9 +300,9 @@ struct IRFunc : IRParentInst { IRFuncType* getType() { return (IRFuncType*) type.usedValue; } - IRType* getResultType() { return getType()->getResultType(); } - UInt getParamCount() { return getType()->getParamCount(); } - IRType* getParamType(UInt index) { return getType()->getParamType(index); } + IRType* getResultType(); + UInt getParamCount(); + IRType* getParamType(UInt index); IRBlock* getFirstBlock() { return (IRBlock*) firstChild; } IRBlock* getLastBlock() { return (IRBlock*) lastChild; } @@ -465,163 +320,10 @@ struct IRModule : IRParentInst IRInstID idCounter; }; -// Description of an instruction to be used for global value numbering -struct IRInstKey -{ - IRInst* inst; - - int GetHashCode(); -}; - -bool operator==(IRInstKey const& left, IRInstKey const& right); - -struct IRConstantKey -{ - IRConstant* inst; - - int GetHashCode(); -}; -bool operator==(IRConstantKey const& left, IRConstantKey const& right); - -struct SharedIRBuilder -{ - // The module that will own all of the IR - IRModule* module; - - Dictionary<IRInstKey, IRInst*> globalValueNumberingMap; - Dictionary<IRConstantKey, IRConstant*> constantMap; -}; - -struct IRBuilder -{ - // Shared state for all IR builders working on the same module - SharedIRBuilder* shared; - - IRModule* getModule() { return shared->module; } - - // The parent instruction to add children to. 
- IRParentInst* parentInst; - - void addInst(IRParentInst* parent, IRInst* inst); - void addInst(IRInst* inst); - - IRType* getBaseType(BaseType flavor); - IRType* getBoolType(); - IRType* getVectorType(IRType* elementType, IRValue* elementCount); - IRType* getMatrixType( - IRType* elementType, - IRValue* rowCount, - IRValue* columnCount); - IRType* getTypeType(); - IRType* getVoidType(); - IRType* getBlockType(); - - IRType* getIntrinsicType( - IROp op, - UInt argCount, - IRValue* const* args); - - IRStructDecl* createStructType(); - IRStructField* createStructField(IRType* fieldType); - - IRType* getFuncType( - UInt paramCount, - IRType* const* paramTypes, - IRType* resultType); - - IRType* getPtrType( - IRType* valueType); - - IRValue* getBoolValue(bool value); - IRValue* getIntValue(IRType* type, IRIntegerValue value); - IRValue* getFloatValue(IRType* type, IRFloatingPointValue value); - - IRInst* emitCallInst( - IRType* type, - IRValue* func, - UInt argCount, - IRValue* const* args); - - IRInst* emitIntrinsicInst( - IRType* type, - IROp op, - UInt argCount, - IRValue* const* args); - - IRInst* emitConstructorInst( - IRType* type, - UInt argCount, - IRValue* const* args); - - IRModule* createModule(); - - IRFunc* createFunc(); - - IRBlock* createBlock(); - IRBlock* emitBlock(); - - IRParam* emitParam( - IRType* type); - - IRVar* emitVar( - IRType* type); - - IRInst* emitLoad( - IRValue* ptr); - - IRInst* emitStore( - IRValue* dstPtr, - IRValue* srcVal); - - IRInst* emitFieldExtract( - IRType* type, - IRValue* base, - IRStructField* field); - - IRInst* emitFieldAddress( - IRType* type, - IRValue* basePtr, - IRStructField* field); - - IRInst* emitElementExtract( - IRType* type, - IRValue* base, - IRValue* index); - - IRInst* emitElementAddress( - IRType* type, - IRValue* basePtr, - IRValue* index); - - - IRInst* emitReturn( - IRValue* val); - - IRInst* emitReturn(); - - IRDecoration* addDecorationImpl( - IRInst* inst, - UInt decorationSize, - IRDecorationOp op); 
- - template<typename T> - T* addDecoration(IRInst* inst, IRDecorationOp op) - { - return (T*) addDecorationImpl(inst, sizeof(T), op); - } - - template<typename T> - T* addDecoration(IRInst* inst) - { - return (T*) addDecorationImpl(inst, sizeof(T), IRDecorationOp(T::kDecorationOp)); - } - - IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRInst* inst, Decl* decl); - IRLayoutDecoration* addLayoutDecoration(IRInst* inst, Layout* layout); -}; - -void dumpIR(IRModule* module); +void printSlangIRAssembly(StringBuilder& builder, IRModule* module); +String getSlangIRAssembly(IRModule* module); } + #endif diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 10b4aefca..5ee6d5460 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -2,44 +2,130 @@ #include "lower-to-ir.h" #include "ir.h" +#include "ir-insts.h" #include "type-layout.h" #include "visitor.h" namespace Slang { -struct BoundMemberInfo; +// This file implements lowering of the Slang AST to a simpler SSA +// intermediate representation. +// +// IR is generated in a context (`IRGenContext`), which tracks the current +// location in the IR where code should be emitted (e.g., what basic +// block to add instructions to). Lowering a statement will emit some +// number of instructions to the context, and possibly change the +// insertion point (because of control flow). +// +// When lowering an expression we have a more interesting challenge, for +// two main reasons: +// +// 1. There might be types that are representible in the AST, but which +// we don't want to support natively in the IR. An example is a `struct` +// type with both ordinary and resource-type members; we might want to +// split values with such a type into distinct values during lowering. +// +// 2. 
We need to handle the difference between l-value and r-value expressions, +// and in particular the fact that HLSL/Slang supports complicated sorts +// of l-values (e.g., `someVector.zxy` is an l-value, even though it can't +// be represented by a single pointer), and also allows l-values to appear +// in multiple contexts (not just the left-hand side of assignment, but +// also as an argument to match an `out` or `in out` parameter). +// +// Our solution to both of these problems is the same. Rather than having +// the lowering of an expression return a single IR-level value (`IRInst*`), +// we have it return a more complex type (`LoweredValInfo`) which can represent +// a wider range of conceptual "values" which might correspond to multiple IR-level +// values, and/or represent a pointer to an l-value rather than the r-value itself. + +// We want to keep the representation of a `LoweringValInfo` relatively light +// - right now it is just a single pointer plus a "tag" to distinguish the cases. +// +// This means that cases that can't fit in a single pointer need a heap allocation +// to store their payload. For simplicity we represent all of these with a class +// hierarchy: +// +struct ExtendedValueInfo : RefObject +{}; -struct SubscriptInfo : RefObject +// This case is used to indicate a value that is a reference +// to an AST-level subscript declaration. +// +struct SubscriptInfo : ExtendedValueInfo { DeclRef<SubscriptDecl> declRef; }; -struct BoundSubscriptInfo : RefObject +// This case is used to indicate a reference to an AST-level +// subscript operation bound to particular arguments. +// +// For example in a case like this: +// +// RWStructuredBuffer<Foo> gBuffer; +// ... gBuffer[someIndex] ... +// +// the expression `gBuffer[someIndex]` will be lowered to +// a value that references `RWStructureBuffer<Foo>::operator[]` +// with arguments `(gBuffer, someIndex)`. 
+// +// Such a value can be an l-value, and depending on the context +// where it is used, can lower into a call to either the getter +// or setter operations of the subscript. +// +struct BoundSubscriptInfo : ExtendedValueInfo { DeclRef<SubscriptDecl> declRef; IRType* type; List<IRValue*> args; }; +// Some cases of `ExtendedValueInfo` need to +// recursively contain `LoweredValInfo`s, and +// so we forward declare them here and fill +// them in later. +// +struct BoundMemberInfo; +struct SwizzledLValueInfo; + + +// This type is our core representation of lowered values. +// In the simple case, it just wraps an `IRInst*`. +// More complex cases, representing l-values or aggregate +// values are also supported. struct LoweredValInfo { + // Which of the cases of value are we looking at? enum class Flavor { + // No value (akin to a null pointer) None, + + // A simple IR value Simple, + + // An l-value reprsented as an IR + // pointer to the value Ptr, + + // A member declaration bound to a particular `this` value BoundMember, + + // A reference to an AST-level subscript operation Subscript, + + // An AST-level subscript operation bound to a particular + // object and arguments. 
BoundSubscript, + + // The result of applying swizzling to an l-value + SwizzledLValue, }; union { IRValue* val; - BoundMemberInfo* boundMemberInfo; - SubscriptInfo* subscriptInfo; - BoundSubscriptInfo* boundSubscriptInfo; + ExtendedValueInfo* ext; }; Flavor flavor; @@ -66,33 +152,87 @@ struct LoweredValInfo } static LoweredValInfo boundMember( - LoweredValInfo const& base, - LoweredValInfo const& member); + BoundMemberInfo* boundMemberInfo); + + BoundMemberInfo* getBoundMemberInfo() + { + assert(flavor == Flavor::BoundMember); + return (BoundMemberInfo*)ext; + } static LoweredValInfo subscript( SubscriptInfo* subscriptInfo); + SubscriptInfo* getSubscriptInfo() + { + assert(flavor == Flavor::Subscript); + return (SubscriptInfo*)ext; + } + static LoweredValInfo boundSubscript( BoundSubscriptInfo* boundSubscriptInfo); + + BoundSubscriptInfo* getBoundSubscriptInfo() + { + assert(flavor == Flavor::BoundSubscript); + return (BoundSubscriptInfo*)ext; + } + + static LoweredValInfo swizzledLValue( + SwizzledLValueInfo* extInfo); + + SwizzledLValueInfo* getSwizzledLValueInfo() + { + assert(flavor == Flavor::SwizzledLValue); + return (SwizzledLValueInfo*)ext; + } }; -struct BoundMemberInfo +// Represents some declaration bound to a particular +// object. For example, if we had `obj.f` where `f` +// is a member function, we'd use a `BoundMemberInfo` +// to represnet this. +// +// Note: This case is largely avoided by special-casing +// in the handling of calls (like `obj.f(arg)`), but +// it is being left here as an example of what we might +// need/want to do in the long term. +struct BoundMemberInfo : ExtendedValueInfo { + // The base object LoweredValInfo base; - LoweredValInfo member; + + // The (AST-level) declaration reference. + DeclRef<Decl> declRef; }; -LoweredValInfo LoweredValInfo::boundMember( - LoweredValInfo const& base, - LoweredValInfo const& member) +// Represents the result of a swizzle operation in +// an l-value context. 
A swizzle without duplicate +// elements is allowed as an l-value, even if the +// element are non-contiguous (`.xz`) or out of +// order (`.zxy`). +// +struct SwizzledLValueInfo : ExtendedValueInfo { - BoundMemberInfo* boundMember = new BoundMemberInfo(); - boundMember->base = base; - boundMember->member = member; + // IR-level The type of the expression. + IRType* type; + + // The base expression (this should be an l-value) + LoweredValInfo base; + + // The number of elements in the swizzle + UInt elementCount; + // THe indices for the elements being swizzled + UInt elementIndices[4]; +}; + +LoweredValInfo LoweredValInfo::boundMember( + BoundMemberInfo* boundMemberInfo) +{ LoweredValInfo info; info.flavor = Flavor::BoundMember; - info.boundMemberInfo = boundMember; + info.ext = boundMemberInfo; return info; } @@ -101,7 +241,7 @@ LoweredValInfo LoweredValInfo::subscript( { LoweredValInfo info; info.flavor = Flavor::Subscript; - info.subscriptInfo = subscriptInfo; + info.ext = subscriptInfo; return info; } @@ -110,10 +250,18 @@ LoweredValInfo LoweredValInfo::boundSubscript( { LoweredValInfo info; info.flavor = Flavor::BoundSubscript; - info.boundSubscriptInfo = boundSubscriptInfo; + info.ext = boundSubscriptInfo; return info; } +LoweredValInfo LoweredValInfo::swizzledLValue( + SwizzledLValueInfo* extInfo) +{ + LoweredValInfo info; + info.flavor = Flavor::SwizzledLValue; + info.ext = extInfo; + return info; +} struct SharedIRGenContext { @@ -123,8 +271,13 @@ struct SharedIRGenContext Dictionary<DeclRef<Decl>, LoweredValInfo> declValues; - // Arrays we keep around strictly for memory-management purposes - List<RefPtr<BoundSubscriptInfo>> boundSubscripts; + // Arrays we keep around strictly for memory-management purposes: + + // Any extended values created during lowering need + // to be cleaned up after the fact. 
We don't try + // to reference-count these along the way because + // they need to get stored into a `union` inside `LoweredValInfo` + List<RefPtr<ExtendedValueInfo>> extValues; }; @@ -178,6 +331,29 @@ LoweredValInfo emitCallToVal( } } +LoweredValInfo emitCompoundAssignOp( + IRGenContext* context, + IRType* type, + IROp op, + UInt argCount, + IRValue* const* args) +{ + auto builder = context->irBuilder; + + assert(argCount == 2); + auto leftPtr = args[0]; + auto rightVal = args[1]; + + auto leftVal = builder->emitLoad(leftPtr); + + IRInst* innerArgs[] = { leftVal, rightVal }; + auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); + + builder->emitStore(leftPtr, innerOp); + + return LoweredValInfo::ptr(leftPtr); +} + // Given a `DeclRef` for something callable, along with a bunch of // arguments, emit an appropriate call to it. LoweredValInfo emitCallToDeclRef( @@ -229,7 +405,7 @@ LoweredValInfo emitCallToDeclRef( boundSubscript->type = type; boundSubscript->args.AddRange(args, argCount); - context->shared->boundSubscripts.Add(boundSubscript); + context->shared->extValues.Add(boundSubscript); return LoweredValInfo::boundSubscript(boundSubscript); } @@ -245,6 +421,35 @@ LoweredValInfo emitCallToDeclRef( { auto op = getIntrinsicOp(funcDecl, intrinsicOpModifier); + if (Int(op) < 0) + { + switch (op) + { + case kIRPseudoOp_Pos: + return LoweredValInfo::simple(args[0]); + +#define CASE(COMPOUND, OP) \ + case COMPOUND: return emitCompoundAssignOp(context, type, OP, argCount, args) + + CASE(kIRPseudoOp_AddAssign, kIROp_Add); + CASE(kIRPseudoOp_SubAssign, kIROp_Sub); + CASE(kIRPseudoOp_MulAssign, kIROp_Mul); + CASE(kIRPseudoOp_DivAssign, kIROp_Div); + CASE(kIRPseudoOp_ModAssign, kIROp_Mod); + CASE(kIRPseudoOp_AndAssign, kIROp_BitAnd); + CASE(kIRPseudoOp_OrAssign, kIROp_BitOr); + CASE(kIRPseudoOp_XorAssign, kIROp_BitXor); + CASE(kIRPseudoOp_LshAssign, kIROp_Lsh); + CASE(kIRPseudoOp_RshAssign, kIROp_Rsh); + +#undef CASE + + default: + SLANG_UNIMPLEMENTED_X("IR 
pseudo-op"); + break; + } + } + return LoweredValInfo::simple(builder->emitIntrinsicInst( type, op, @@ -277,6 +482,8 @@ LoweredValInfo emitCallToDeclRef( IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) { + auto builder = context->irBuilder; + top: switch(lowered.flavor) { @@ -287,12 +494,11 @@ top: return lowered.val; case LoweredValInfo::Flavor::Ptr: - return context->irBuilder->emitLoad(lowered.val); + return builder->emitLoad(lowered.val); case LoweredValInfo::Flavor::BoundSubscript: { - auto boundSubscriptInfo = lowered.boundSubscriptInfo; - auto builder = context->irBuilder; + auto boundSubscriptInfo = lowered.getBoundSubscriptInfo(); for (auto getter : getMembersOfType<GetterDecl>(boundSubscriptInfo->declRef)) { @@ -309,6 +515,17 @@ top: } break; + case LoweredValInfo::Flavor::SwizzledLValue: + { + auto swizzleInfo = lowered.getSwizzledLValueInfo(); + + return builder->emitSwizzle( + swizzleInfo->type, + getSimpleVal(context, swizzleInfo->base), + swizzleInfo->elementCount, + swizzleInfo->elementIndices); + } + default: SLANG_UNEXPECTED("unhandled value flavor"); return nullptr; @@ -397,8 +614,11 @@ IRType* lowerSimpleType( return getSimpleType(lowered); } +LoweredValInfo lowerLValueExpr( + IRGenContext* context, + Expr* expr); -LoweredValInfo lowerExpr( +LoweredValInfo lowerRValueExpr( IRGenContext* context, Expr* expr); @@ -528,6 +748,38 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount); } + + LoweredTypeInfo getArrayType( + LoweredTypeInfo const& loweredElementType, + IRValue* irElementCount) + { + switch (loweredElementType.flavor) + { + case LoweredTypeInfo::Flavor::Simple: + return getBuilder()->getArrayType( + loweredElementType.type, + irElementCount); + break; + + default: + SLANG_UNEXPECTED("array element type"); + break; + } + } + + LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type) + { + auto 
loweredElementType = lowerType(context, type->BaseType); + if (auto elementCount = type->ArrayLength) + { + auto irElementCount = lowerSimpleVal(context, elementCount); + return getArrayType(loweredElementType, irElementCount); + } + else + { + return getArrayType(loweredElementType, nullptr); + } + } }; LoweredValInfo lowerVal( @@ -556,15 +808,79 @@ struct LoweringVisitor , ValVisitor<LoweringVisitor, RefPtr<Val>, RefPtr<Type>> #endif +LoweredValInfo createVar( + IRGenContext* context, + LoweredTypeInfo type, + Decl* decl = nullptr, + Layout* layout = nullptr, + IRAddressSpace addressSpace = kIRAddressSpace_Default) +{ + auto builder = context->irBuilder; + switch( type.flavor ) + { + case LoweredTypeInfo::Flavor::Simple: + { + auto irAlloc = builder->emitVar(getSimpleType(type), addressSpace); + + if (decl) + { + builder->addHighLevelDeclDecoration(irAlloc, decl); + } + + if (layout) + { + builder->addLayoutDecoration(irAlloc, layout); + } + + + return LoweredValInfo::ptr(irAlloc); + } + break; + + default: + SLANG_UNIMPLEMENTED_X("var type"); + return LoweredValInfo(); + } + +} + +void addArgs( + IRGenContext* context, + List<IRValue*>* ioArgs, + LoweredValInfo argInfo) +{ + auto& args = *ioArgs; + switch( argInfo.flavor ) + { + case LoweredValInfo::Flavor::Simple: + case LoweredValInfo::Flavor::Ptr: + case LoweredValInfo::Flavor::SwizzledLValue: + args.Add(getSimpleVal(context, argInfo)); + break; + + default: + SLANG_UNIMPLEMENTED_X("addArgs case"); + break; + } +} // -struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> +template<typename Derived> +struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> { IRGenContext* context; IRBuilder* getBuilder() { return context->irBuilder; } + // Lower an expression that should have the same l-value-ness + // as the visitor itself. 
+ LoweredValInfo lowerSubExpr(Expr* expr) + { + return dispatch(expr); + } + + LoweredValInfo visitVarExpr(VarExpr* expr) { LoweredValInfo info = ensureDecl(context, expr->declRef); @@ -576,6 +892,83 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST"); } + LoweredValInfo visitIndexExpr(IndexExpr* expr) + { + auto type = lowerType(context, expr->type); + auto baseVal = lowerSubExpr(expr->BaseExpression); + auto indexVal = getSimpleVal(context, lowerRValueExpr(context, expr->IndexExpression)); + + return subscriptValue(type, baseVal, indexVal); + } + + LoweredValInfo visitMemberExpr(MemberExpr* expr) + { + auto loweredType = lowerType(context, expr->type); + auto loweredBase = lowerRValueExpr(context, expr->BaseExpression); + + auto declRef = expr->declRef; + if (auto fieldDeclRef = declRef.As<StructField>()) + { + // Okay, easy enough: we have a reference to a field of a struct type... + + auto loweredField = ensureDecl(context, fieldDeclRef); + return extractField(loweredType, loweredBase, loweredField); + } + else if (auto callableDeclRef = declRef.As<CallableDecl>()) + { + RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo(); + boundMemberInfo->base = loweredBase; + boundMemberInfo->declRef = callableDeclRef; + return LoweredValInfo::boundMember(boundMemberInfo); + } + + SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); + } + + // We will always lower a dereference expression (`*ptr`) + // as an l-value, since that is the easiest way to handle it. + LoweredValInfo visitDerefExpr(DerefExpr* expr) + { + auto loweredType = lowerType(context, expr->type); + auto loweredBase = lowerRValueExpr(context, expr->base); + + // TODO: handle tupel-type for `base` + + // The type of the lowered base must by some kind of pointer, + // in order for a dereference to make senese, so we just + // need to extract the value type from that pointer here. 
+ // + auto loweredBaseVal = getSimpleVal(context, loweredBase); + auto loweredBaseType = loweredBaseVal->getType(); + switch( loweredBaseType->op ) + { + case kIROp_PtrType: + // TODO: should we enumerate these explicitly? + case kIROp_ConstantBufferType: + case kIROp_TextureBufferType: + // Note that we do *not* perform an actual `load` operation + // here, but rather just use the pointer value to construct + // an appropriate `LoweredValInfo` representing the underlying + // dereference. + // + // This is important so that an expression like `&((*foo).bar)` + // (which is desugared from `&foo->bar`) can be handled; such + // an expression does *not* perform a dereference at runtime, + // and is just a bit of pointer math. + // + return LoweredValInfo::ptr(loweredBaseVal); + + default: + SLANG_UNIMPLEMENTED_X("codegen for deref expression"); + return LoweredValInfo(); + } + } + + LoweredValInfo visitParenExpr(ParenExpr* expr) + { + return lowerSubExpr(expr->base); + } + LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for initializer list expression"); @@ -605,23 +998,6 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression"); } - void addArgs(List<IRValue*>* ioArgs, LoweredValInfo argInfo) - { - auto& args = *ioArgs; - switch( argInfo.flavor ) - { - case LoweredValInfo::Flavor::Simple: - case LoweredValInfo::Flavor::Ptr: - args.Add(getSimpleVal(context, argInfo)); - break; - - default: - SLANG_UNIMPLEMENTED_X("addArgs case"); - break; - } - } - - // Add arguments that appeared directly in an argument list // to the list of argument values for a call. 
void addDirectCallArgs( @@ -632,45 +1008,129 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> for( auto arg : expr->Arguments ) { - auto loweredArg = lowerExpr(context, arg); - addArgs(&irArgs, loweredArg); + // TODO: Need to handle case of l-value arguments, + // when they are matched to `out` or `in out` parameters. + auto loweredArg = lowerRValueExpr(context, arg); + addArgs(context, ioArgs, loweredArg); } } - // Try to add "all" the arguments for a call to the argument list, - // including implicit arguments that come from (e.g.,) a member - // expression used to form the call. - void addCallArgs( - InvokeExpr* expr, - List<IRValue*>* ioArgs) - { - auto& irArgs = *ioArgs; + // After a call to a function with `out` or `in out` + // parameters, we may need to copy data back into + // the l-value locations used for output arguments. + // + // During lowering of the argument list, we build + // up a list of these "fixup" assignments that need + // to be performed. + struct OutArgumentFixup + { + LoweredValInfo dst; + LoweredValInfo src; + }; - // TODO: should unwrap any layers of identity expressions around this... - if( auto baseMemberExpr = expr->FunctionExpr.As<MemberExpr>() ) + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef<CallableDecl> funcDeclRef, + List<IRValue*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + auto funcDecl = funcDeclRef.getDecl(); + auto& args = expr->Arguments; + UInt argCount = expr->Arguments.Count(); + UInt argIndex = 0; + for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef)) { - // This call took the form of a member function call, so - // we need to correctly add the `this` argument as - // an explicit argument. - // - auto loweredBase = lowerExpr(context, baseMemberExpr->BaseExpression); - addArgs(&irArgs, loweredBase); - } + if (argIndex >= argCount) + { + // The remaining parameters must be defaulted... 
+ break; + } - addDirectCallArgs(expr, ioArgs); - } + auto paramDecl = paramDeclRef.getDecl(); + auto paramType = lowerType(context, GetType(paramDeclRef)); + auto argExpr = expr->Arguments[argIndex++]; - LoweredValInfo lowerIntrinsicCall( - InvokeExpr* expr, - IROp intrinsicOp) - { - auto type = lowerSimpleType(context, expr->type); + if (paramDecl->HasModifier<OutModifier>() + || paramDecl->HasModifier<InOutModifier>()) + { + // This is a `out` or `inout` parameter, and so + // the argument must be lowered as an l-value. + + LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr); + + // According to our "calling convention" we need to + // pass a pointer into the callee. + // + // A naive approach would be to just take the address + // of `loweredArg` above and pass it in, but that + // has two issues: + // + // 1. The l-value might not be something that has a single + // well-defined "address" (e.g., `foo.xzy`). + // + // 2. The l-value argument might actually alias some other + // storage that the callee will access (e.g., we are + // passing in a global variable, or two `out` parameters + // are being passed the same location in an array). + // + // In each of these cases, the safe option is to create + // a temporary variable to use for argument-passing, + // and then do copy-in/copy-out around the call. + + LoweredValInfo tempVar = createVar(context, paramType); + + // If the parameter is `in out` or `inout`, then we need + // to ensure that we pass in the original value stored + // in the argument, which we accomplish by assigning + // from the l-value to our temp. 
+ if (paramDecl->HasModifier<InModifier>() + || paramDecl->HasModifier<InOutModifier>()) + { + assign(context, tempVar, loweredArg); + } - List<IRValue*> irArgs; - addCallArgs(expr, &irArgs); - UInt argCount = irArgs.Count(); + // Now we can pass the address of the temporary variable + // to the callee as the actual argument for the `in out` + assert(tempVar.flavor == LoweredValInfo::Flavor::Ptr); + (*ioArgs).Add(tempVar.val); - return LoweredValInfo::simple(getBuilder()->emitIntrinsicInst(type, intrinsicOp, argCount, &irArgs[0])); + // Finally, after the call we will need + // to copy in the other direction: from our + // temp back to the original l-value. + OutArgumentFixup fixup; + fixup.src = tempVar; + fixup.dst = loweredArg; + + (*ioFixups).Add(fixup); + + } + else + { + // This is a pure input parameter, and so we will + // pass it as an r-value. + LoweredValInfo loweredArg = lowerRValueExpr(context, argExpr); + addArgs(context, ioArgs, loweredArg); + } + } + } + + // Add arguments that appeared directly in an argument list + // to the list of argument values for a call. + void addDirectCallArgs( + InvokeExpr* expr, + DeclRef<Decl> funcDeclRef, + List<IRValue*>* ioArgs, + List<OutArgumentFixup>* ioFixups) + { + if (auto callableDeclRef = funcDeclRef.As<CallableDecl>()) + { + addDirectCallArgs(expr, callableDeclRef, ioArgs, ioFixups); + } + else + { + SLANG_UNEXPECTED("shouldn't relaly happen"); + addDirectCallArgs(expr, ioArgs); + } } void addFuncBaseArgs( @@ -684,9 +1144,12 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> } } - LoweredValInfo lowerSimpleCall(InvokeExpr* expr) + void applyOutArgumentFixups(List<OutArgumentFixup> const& fixups) { - + for (auto fixup : fixups) + { + assign(context, fixup.dst, fixup.src); + } } LoweredValInfo visitInvokeExpr(InvokeExpr* expr) @@ -711,40 +1174,60 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> // arguments that will be part of the call. 
List<IRValue*> irArgs; + // We will also collect "fixup" actions that need + // to be performed after teh call, in order to + // copy the final values for `out` parameters + // back to their arguments. + List<OutArgumentFixup> argFixups; auto funcExpr = expr->FunctionExpr; if (auto memberFuncExpr = funcExpr.As<MemberExpr>()) { - auto loweredBaseVal = lowerExpr(context, memberFuncExpr->BaseExpression); - addArgs(&irArgs, loweredBaseVal); + auto loweredBaseVal = lowerRValueExpr(context, memberFuncExpr->BaseExpression); + addArgs(context, &irArgs, loweredBaseVal); auto funcDeclRef = memberFuncExpr->declRef; - addDirectCallArgs(expr, &irArgs); - return emitCallToDeclRef(context, type, funcDeclRef, irArgs); + addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); + auto result = emitCallToDeclRef(context, type, funcDeclRef, irArgs); + applyOutArgumentFixups(argFixups); + return result; } else if (auto staticMemberFuncExpr = funcExpr.As<StaticMemberExpr>()) { auto funcDeclRef = staticMemberFuncExpr->declRef; - addDirectCallArgs(expr, &irArgs); - return emitCallToDeclRef(context, type, funcDeclRef, irArgs); + addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); + auto result = emitCallToDeclRef(context, type, funcDeclRef, irArgs); + applyOutArgumentFixups(argFixups); + return result; } else if (auto varExpr = funcExpr.As<VarExpr>()) { auto funcDeclRef = varExpr->declRef; - addDirectCallArgs(expr, &irArgs); - return emitCallToDeclRef(context, type, funcDeclRef, irArgs); + addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups); + auto result = emitCallToDeclRef(context, type, funcDeclRef, irArgs); + applyOutArgumentFixups(argFixups); + return result; } // The default case is to assume that we just have // an ordinary expression, and can lower it as such. - LoweredValInfo funcVal = lowerExpr(context, expr->FunctionExpr); + LoweredValInfo funcVal = lowerRValueExpr(context, expr->FunctionExpr); // Now we add any direct arguments from the call expression itself. 
addDirectCallArgs(expr, &irArgs); // Delegate to the logic for invoking a value. - return emitCallToVal(context, type, funcVal, irArgs.Count(), irArgs.Buffer()); + auto result = emitCallToVal(context, type, funcVal, irArgs.Count(), irArgs.Buffer()); + + // TODO: because of the nature of how the `emitCallToVal` case works + // right now, we don't have information on in/out parameters, and + // so we can't collect info to apply fixups. + // + // Once we have a better representation for function types, though, + // this should be fixable. + + return result; } LoweredValInfo subscriptValue( @@ -776,15 +1259,6 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> } - LoweredValInfo visitIndexExpr(IndexExpr* expr) - { - auto type = lowerType(context, expr->type); - auto baseVal = lowerExpr(context, expr->BaseExpression); - auto indexVal = getSimpleVal(context, lowerExpr(context, expr->IndexExpression)); - - return subscriptValue(type, baseVal, indexVal); - } - LoweredValInfo extractField( LoweredTypeInfo fieldType, LoweredValInfo base, @@ -824,70 +1298,6 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> return ensureDecl(context, expr->declRef); } - LoweredValInfo visitMemberExpr(MemberExpr* expr) - { - auto loweredType = lowerType(context, expr->type); - auto loweredBase = lowerExpr(context, expr->BaseExpression); - - auto declRef = expr->declRef; - if (auto fieldDeclRef = declRef.As<StructField>()) - { - // Okay, easy enough: we have a reference to a field of a struct type... 
- - auto loweredField = ensureDecl(context, fieldDeclRef); - return extractField(loweredType, loweredBase, loweredField); - } - else if (auto callableDeclRef = declRef.As<CallableDecl>()) - { - auto loweredFunc = ensureDecl(context, callableDeclRef); - return LoweredValInfo::boundMember(loweredBase, loweredFunc); - } - - SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); - } - - LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) - { - SLANG_UNIMPLEMENTED_X("codegen for swizzle expression"); - } - - LoweredValInfo visitDerefExpr(DerefExpr* expr) - { - auto loweredType = lowerType(context, expr->type); - auto loweredBase = lowerExpr(context, expr->base); - - // TODO: handle tupel-type for `base` - - // The type of the lowered base must by some kind of pointer, - // in order for a dereference to make senese, so we just - // need to extract the value type from that pointer here. - // - auto loweredBaseVal = getSimpleVal(context, loweredBase); - auto loweredBaseType = loweredBaseVal->getType(); - switch( loweredBaseType->op ) - { - case kIROp_PtrType: - // TODO: should we enumerate these explicitly? - case kIROp_ConstantBufferType: - case kIROp_TextureBufferType: - // Note that we do *not* perform an actual `load` operation - // here, but rather just use the pointer value to construct - // an appropriate `LoweredValInfo` representing the underlying - // dereference. - // - // This is important so that an expression like `&((*foo).bar)` - // (which is desugared from `&foo->bar`) can be handled; such - // an expression does *not* perform a dereference at runtime, - // and is just a bit of pointer math. 
- // - return LoweredValInfo::ptr(loweredBaseVal); - - default: - SLANG_UNIMPLEMENTED_X("codegen for deref expression"); - return LoweredValInfo(); - } - } - LoweredValInfo visitTypeCastExpr(TypeCastExpr* expr) { SLANG_UNIMPLEMENTED_X("codegen for type cast expression"); @@ -916,8 +1326,8 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> // and right-hand sides, and then peform an assignment // based on the resulting values. // - auto leftVal = lowerExpr(context, expr->left); - auto rightVal = lowerExpr(context, expr->right); + auto leftVal = lowerLValueExpr(context, expr->left); + auto rightVal = lowerRValueExpr(context, expr->right); assign(context, leftVal, rightVal); // The result value of the assignment expression is @@ -925,18 +1335,80 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> // to be an l-value). return leftVal; } +}; - LoweredValInfo visitParenExpr(ParenExpr* expr) +struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVisitor> +{ + // When visiting a swizzle expression in an l-value context, + // we need to construct a "sizzled l-value." 
+ LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) { - return lowerExpr(context, expr->base); + auto irType = lowerSimpleType(context, expr->type); + auto loweredBase = lowerRValueExpr(context, expr->base); + + RefPtr<SwizzledLValueInfo> swizzledLValue = new SwizzledLValueInfo(); + swizzledLValue->type = irType; + swizzledLValue->base = loweredBase; + + UInt elementCount = (UInt)expr->elementCount; + swizzledLValue->elementCount = elementCount; + for (UInt ii = 0; ii < elementCount; ++ii) + { + swizzledLValue->elementIndices[ii] = (UInt) expr->elementIndices[ii]; + } + + context->shared->extValues.Add(swizzledLValue); + return LoweredValInfo::swizzledLValue(swizzledLValue); + } + +}; + +struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVisitor> +{ + // A swizzle in an r-value context can save time by just + // emitting the swizzle instuctions directly. + LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr) + { + auto irType = lowerSimpleType(context, expr->type); + auto irBase = getSimpleVal(context, lowerRValueExpr(context, expr->base)); + + auto builder = getBuilder(); + + auto irIntType = builder->getBaseType(BaseType::Int); + + UInt elementCount = (UInt)expr->elementCount; + IRValue* irElementIndices[4]; + for (UInt ii = 0; ii < elementCount; ++ii) + { + irElementIndices[ii] = builder->getIntValue( + irIntType, + (IRIntegerValue)expr->elementIndices[ii]); + } + + auto irSwizzle = builder->emitSwizzle( + irType, + irBase, + elementCount, + &irElementIndices[0]); + + return LoweredValInfo::simple(irSwizzle); } }; -LoweredValInfo lowerExpr( +LoweredValInfo lowerLValueExpr( + IRGenContext* context, + Expr* expr) +{ + LValueExprLoweringVisitor visitor; + visitor.context = context; + return visitor.dispatch(expr); +} + +LoweredValInfo lowerRValueExpr( IRGenContext* context, Expr* expr) { - ExprLoweringVisitor visitor; + RValueExprLoweringVisitor visitor; visitor.context = context; return visitor.dispatch(expr); } @@ -952,6 +1424,185 @@ 
struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> SLANG_UNIMPLEMENTED_X("stmt catch-all"); } + // Create a basic block in the current function, + // so that it can be used for a label. + IRBlock* createBlock() + { + return getBuilder()->createBlock(); + } + + // Insert a block at the current location (ending + // the previous block with an unconditional jump + // if needed). + void insertBlock(IRBlock* block) + { + auto builder = getBuilder(); + auto parent = builder->parentInst; + + IRBlock* prevBlock = nullptr; + IRFunc* parentFunc = nullptr; + + switch (parent->op) + { + case kIROp_Block: + prevBlock = (IRBlock*)parent; + parentFunc = prevBlock->getParent(); + break; + + default: + SLANG_UNEXPECTED("bad parent kind for block"); + return; + } + + // If the previous block doesn't already have + // a terminator instruction, then be sure to + // emit a branch to the new block. + if (!isTerminatorInst(prevBlock->lastChild)) + { + builder->emitBranch(block); + } + + builder->parentInst = parentFunc; + builder->addInst(block); + builder->parentInst = block; + } + + // Start a new block at the current location. + // This is just the composition of `createBlock` + // and `insertBlock`. 
+ IRBlock* startBlock() + { + auto block = createBlock(); + insertBlock(block); + return block; + } + + void visitIfStmt(IfStmt* stmt) + { + auto builder = getBuilder(); + + auto condExpr = stmt->Predicate; + auto thenStmt = stmt->PositiveStatement; + auto elseStmt = stmt->NegativeStatement; + + auto irCond = getSimpleVal(context, + lowerRValueExpr(context, condExpr)); + + if (elseStmt) + { + auto thenBlock = createBlock(); + auto elseBlock = createBlock(); + auto afterBlock = createBlock(); + + builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock); + + insertBlock(thenBlock); + lowerStmt(context, thenStmt); + builder->emitBranch(afterBlock); + + insertBlock(elseBlock); + lowerStmt(context, elseStmt); + + insertBlock(afterBlock); + } + else + { + auto thenBlock = createBlock(); + auto afterBlock = createBlock(); + + builder->emitIf(irCond, thenBlock, afterBlock); + + insertBlock(thenBlock); + lowerStmt(context, thenStmt); + + insertBlock(afterBlock); + } + } + + void addLoopDecorations( + IRInst* inst, + Stmt* stmt) + { + for(auto attr : stmt->GetModifiersOfType<HLSLUncheckedAttribute>()) + { + // TODO: We should actually catch these attributes during + // semantic checking, so that they have a strongly-typed + // representation in the AST. + if(getText(attr->getName()) == "unroll") + { + auto decoration = getBuilder()->addDecoration<IRLoopControlDecoration>(inst); + decoration->mode = kIRLoopControl_Unroll; + } + } + } + + void visitForStmt(ForStmt* stmt) + { + auto builder = getBuilder(); + + // The initializer clause for the statement + // can always safely be emitted to the current block. + if (auto initStmt = stmt->InitialStatement) + { + lowerStmt(context, initStmt); + } + + // We will create blocks for the various places + // we need to jump to inside the control flow, + // including the blocks that will be referenced + // by `continue` or `break` statements. 
+ auto loopHead = createBlock(); + auto bodyLabel = createBlock(); + auto breakLabel = createBlock(); + auto continueLabel = createBlock(); + + // TODO: register `loopHead` as the target for a + // `continue` statement. + + // Emit the branch that will start our loop, + // and then insert the block for the head. + + auto loopInst = builder->emitLoop( + loopHead, + breakLabel, + continueLabel); + + addLoopDecorations(loopInst, stmt); + + insertBlock(loopHead); + + // Now that we are within the header block, we + // want to emit the expression for the loop condition: + if (auto condExpr = stmt->PredicateExpression) + { + auto irCondition = getSimpleVal(context, + lowerRValueExpr(context, stmt->PredicateExpression)); + + // Now we want to `break` if the loop condition is false. + builder->emitLoopTest( + irCondition, + bodyLabel, + breakLabel); + } + + // Emit the body of the loop + insertBlock(bodyLabel); + lowerStmt(context, stmt->Statement); + + // Insert the `continue` block + insertBlock(continueLabel); + if (auto incrExpr = stmt->SideEffectExpression) + { + lowerRValueExpr(context, incrExpr); + } + + // At the end of the body we need to jump back to the top. + builder->emitBranch(loopHead); + + // Finally we insert the label that a `break` will jump to + insertBlock(breakLabel); + } + void visitExpressionStmt(ExpressionStmt* stmt) { // The statement evaluates an expression @@ -960,7 +1611,11 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // lower the expression, and don't use // the result. // - lowerExpr(context, stmt->Expression); + // Note that we lower using the l-value path, + // so that an expression statement that names + // a location (but doesn't load from it) + // will not actually emit a load. + lowerLValueExpr(context, stmt->Expression); } void visitDeclStmt(DeclStmt* stmt) @@ -1004,7 +1659,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor> // a value first, and then emit the resulting value. 
if( auto expr = stmt->Expression ) { - auto loweredExpr = lowerExpr(context, expr); + auto loweredExpr = lowerRValueExpr(context, expr); getBuilder()->emitReturn(getSimpleVal(context, loweredExpr)); } @@ -1026,9 +1681,15 @@ void lowerStmt( void assign( IRGenContext* context, - LoweredValInfo const& left, - LoweredValInfo const& right) + LoweredValInfo const& inLeft, + LoweredValInfo const& inRight) { + LoweredValInfo left = inLeft; + LoweredValInfo right = inRight; + + auto builder = context->irBuilder; + +top: switch (left.flavor) { case LoweredValInfo::Flavor::Ptr: @@ -1036,8 +1697,8 @@ void assign( { case LoweredValInfo::Flavor::Simple: case LoweredValInfo::Flavor::Ptr: + case LoweredValInfo::Flavor::SwizzledLValue: { - auto builder = context->irBuilder; builder->emitStore( left.val, getSimpleVal(context, right)); @@ -1050,6 +1711,90 @@ void assign( } break; + case LoweredValInfo::Flavor::SwizzledLValue: + { + // The `left` value is of the form `<someLValue>.<swizzleElements>`. + // + // We could conceivably define a custom "swizzled store" instruction + // that would handle the common case where the base l-value is + // a simple lvalue (`LoweredValInfo::Flavor::Ptr`): + // + // float4 foo; + // foo.zxy = float3(...); + // + // However, this doesn't handle complex cases like the following: + // + // RWStructuredBuffer<float4> foo; + // ... + // foo[index].xzy = float3(...); + // + // In a case like that, we really need to lower through a temp: + // + // float4 tmp = foo[index]; + // tmp.xzy = float3(...); + // foo[index] = tmp; + // + // We want to handle the general case, so we might as well + // try to handle everything uniformly. 
+ // + auto swizzleInfo = left.getSwizzledLValueInfo(); + auto type = swizzleInfo->type; + auto loweredBase = swizzleInfo->base; + + // Load from the base value: + IRInst* irLeftVal = getSimpleVal(context, loweredBase); + auto irRightVal = getSimpleVal(context, right); + + // Now apply the swizzle + IRInst* irSwizzled = builder->emitSwizzleSet( + irLeftVal->getType(), + irLeftVal, + irRightVal, + swizzleInfo->elementCount, + swizzleInfo->elementIndices); + + // And finally, store the value back where we got it. + // + // Note: this is effectively a recursive call to + // `assign()`, so we do a simple tail-recursive call here. + left = loweredBase; + right = LoweredValInfo::simple(irSwizzled); + goto top; + } + break; + + case LoweredValInfo::Flavor::BoundSubscript: + { + // The `left` value refers to a subscript operation on + // a resource type, bound to particular arguments, e.g.: + // `someStructuredBuffer[index]`. + // + // When storing to such a value, we need to emit a call + // to the appropriate builtin "setter" accessor. + auto subscriptInfo = left.getBoundSubscriptInfo(); + auto type = subscriptInfo->type; + + // Search for an appropriate "setter" declaration + for (auto setterDeclRef : getMembersOfType<SetterDecl>(subscriptInfo->declRef)) + { + auto allArgs = subscriptInfo->args; + + addArgs(context, &allArgs, right); + + emitCallToDeclRef( + context, + builder->getVoidType(), + setterDeclRef, + allArgs); + return; + } + + // No setter found? Then we have an error! 
+ SLANG_UNEXPECTED("no setter found"); + break; + } + break; + default: SLANG_UNIMPLEMENTED_X("assignment"); break; @@ -1120,33 +1865,44 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> auto varType = lowerType(context, decl->getType()); - LoweredValInfo varVal; + // TODO: If the variable is marked `static` then we need to + // deal with it specially: we should move its allocation out + // to the global scope, and then we have to deal with its + // initializer expression a bit carefully (it should only + // be initialized on-demand at its first use). + + // Some qualifiers on a variable will change how we allocate it, + // so we need to reflect that somehow. The first example + // we run into is the `groupshared` qualifier, which marks + // a variable in a compute shader as having per-group allocation + // rather than the traditional per-thread (or rather per-thread + // per-activation-record) allocation. + // + // Options include: + // + // - Use a distinct allocation operation, so that the type + // of the variable address/value is unchanged. + // + // - Add a notion of an "address space" to pointer types, + // so that we can allocate things in distinct spaces. + // + // - Add a notion of a "rate" so that we can declare a + // variable with a distinct rate. + // + // For now we might do the expedient thing and handle this + // via a notion of an "address space." 
- switch( varType.flavor ) + IRAddressSpace addressSpace = kIRAddressSpace_Default; + if (decl->HasModifier<HLSLGroupSharedModifier>()) { - case LoweredTypeInfo::Flavor::Simple: - { - auto irAlloc = getBuilder()->emitVar(getSimpleType(varType)); - - getBuilder()->addHighLevelDeclDecoration(irAlloc, decl); - - if (getLayout()) - { - getBuilder()->addLayoutDecoration(irAlloc, getLayout()); - } - - - varVal = LoweredValInfo::ptr(irAlloc); - } - break; - - default: - SLANG_UNIMPLEMENTED_X("struct field type"); + addressSpace = kIRAddressSpace_GroupShared; } + LoweredValInfo varVal = createVar(context, varType, decl, getLayout(), addressSpace); + if( auto initExpr = decl->initExpr ) { - auto initVal = lowerExpr(context, initExpr); + auto initVal = lowerRValueExpr(context, initExpr); assign(context, varVal, initVal); } @@ -1241,6 +1997,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> IRParam* irParam = subBuilder->emitParam(irParamType); + subBuilder->addHighLevelDeclDecoration(irParam, paramDecl); + DeclRef<ParamDecl> paramDeclRef = makeDeclRef(paramDecl.Ptr()); LoweredValInfo irParamVal = LoweredValInfo::simple(irParam); @@ -1278,6 +2036,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> lowerStmt(subContext, decl->Body); + // We need to carefully add a terminator instruction to the end + // of the body, in case the user didn't do so. 
+ if (!isTerminatorInst(subContext->irBuilder->parentInst->lastChild)) + { + if (irResultType->op == kIROp_VoidType) + { + subContext->irBuilder->emitReturn(); + } + else + { + SLANG_UNEXPECTED("Needed a return here"); + subContext->irBuilder->emitReturn(); + } + } + getBuilder()->addHighLevelDeclDecoration(irFunc, decl); getBuilder()->addInst(irFunc); @@ -1365,11 +2138,48 @@ static void lowerEntryPointToIR( EntryPointRequest* entryPointRequest, EntryPointLayout* entryPointLayout) { - auto entryPointFunc = entryPointLayout->entryPoint; + // First, lower the entry point like an ordinary function + auto entryPointFuncDecl = entryPointLayout->entryPoint; + auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl, entryPointLayout); + auto irFunc = getSimpleVal(context, loweredEntryPointFunc); + + auto builder = context->irBuilder; + + // We are going to attach all the entry-point-specific information + // to the declaration as meta-data decorations for now. + // + // I'm not convinced this is the right way to go, but it is + // the easiest and most expedient thing. + // + auto profile = entryPointRequest->profile; + auto stage = profile.GetStage(); - // TODO: entry point lowering is probably *not* just like lowering a function... + auto entryPointDecoration = builder->addDecoration<IREntryPointDecoration>(irFunc); + entryPointDecoration->profile = profile; - lowerDecl(context, entryPointFunc, entryPointLayout); + // Next, we need to start attaching the meta-data that is + // required based on the particular stage we are targeting: + switch (stage) + { + case Stage::Compute: + { + // We need to attach information about the thread group size here. + auto threadGroupSizeDecoration = builder->addDecoration<IRComputeThreadGroupSizeDecoration>(irFunc); + static const UInt kAxisCount = 3; + + // TODO: this is kind of gross because we are using a public + // reflection API function, rather than some kind of internal + // utility it forwards to... 
+ spReflectionEntryPoint_getComputeThreadGroupSize( + (SlangReflectionEntryPoint*)entryPointLayout, + kAxisCount, + &threadGroupSizeDecoration->sizeAlongAxis[0]); + } + break; + + default: + break; + } } IRModule* lowerEntryPointToIR( @@ -1412,4 +2222,17 @@ IRModule* lowerEntryPointToIR( } +String emitSlangIRAssemblyForEntryPoint( + EntryPointRequest* entryPoint) +{ + auto compileRequest = entryPoint->compileRequest; + auto irModule = lowerEntryPointToIR( + entryPoint, + compileRequest->layout.Ptr(), + // TODO: we need to pick the target more carefully here + CodeGenTarget::HLSL); + + return getSlangIRAssembly(irModule); +} + } diff --git a/source/slang/options.cpp b/source/slang/options.cpp index 669ab1827..a87692700 100644 --- a/source/slang/options.cpp +++ b/source/slang/options.cpp @@ -300,6 +300,8 @@ struct OptionsParser CASE(spirv, SPIRV); CASE(spirv-assembly, SPIRV_ASM); + CASE(slang-ir, IR); + CASE(slang-ir-assembly, IR_ASM); CASE(none, TARGET_NONE); #undef CASE diff --git a/source/slang/profile-defs.h b/source/slang/profile-defs.h index 84153ee46..513ba3078 100644 --- a/source/slang/profile-defs.h +++ b/source/slang/profile-defs.h @@ -47,6 +47,8 @@ LANGUAGE(GLSL_ES, glsl_es) LANGUAGE(GLSL_VK, glsl_vk) LANGUAGE(SPIRV, spirv) LANGUAGE(SPIRV_GL, spirv_gl) +LANGUAGE(SlangIR, slang_ir) +LANGUAGE(SlangIRAssembly, slang_ir_assembly) LANGUAGE_ALIAS(GLSL, glsl_gl) LANGUAGE_ALIAS(SPIRV, spirv_vk) diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj index 1f55138e4..08a448a6d 100644 --- a/source/slang/slang.vcxproj +++ b/source/slang/slang.vcxproj @@ -172,6 +172,7 @@ <ClInclude Include="emit.h" /> <ClInclude Include="expr-defs.h" /> <ClInclude Include="ir-inst-defs.h" /> + <ClInclude Include="ir-insts.h" /> <ClInclude Include="ir.h" /> <ClInclude Include="lexer.h" /> <ClInclude Include="lookup.h" /> diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters index 9a85ce966..111b23d36 100644 --- 
a/source/slang/slang.vcxproj.filters +++ b/source/slang/slang.vcxproj.filters @@ -39,6 +39,7 @@ <ClInclude Include="ir.h" /> <ClInclude Include="lower-to-ir.h" /> <ClInclude Include="ir-inst-defs.h" /> + <ClInclude Include="ir-insts.h" /> </ItemGroup> <ItemGroup> <ClCompile Include="check.cpp" /> @@ -65,8 +66,8 @@ <ClCompile Include="lower-to-ir.cpp" /> </ItemGroup> <ItemGroup> - <None Include="core.meta.slang" /> - <None Include="glsl.meta.slang" /> - <None Include="hlsl.meta.slang" /> + <CustomBuild Include="core.meta.slang" /> + <CustomBuild Include="glsl.meta.slang" /> + <CustomBuild Include="hlsl.meta.slang" /> </ItemGroup> </Project>
\ No newline at end of file diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp index 00a625427..ec71df37d 100644 --- a/source/slang/type-layout.cpp +++ b/source/slang/type-layout.cpp @@ -621,6 +621,8 @@ LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(CodeGenTarget target) case CodeGenTarget::HLSL: case CodeGenTarget::DXBytecode: case CodeGenTarget::DXBytecodeAssembly: + case CodeGenTarget::SlangIR: + case CodeGenTarget::SlangIRAssembly: return &kHLSLLayoutRulesFamilyImpl; case CodeGenTarget::GLSL: |
