summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2017-09-14 15:37:05 -0700
committerGitHub <noreply@github.com>2017-09-14 15:37:05 -0700
commit10b62eecd94be53eca4ac2555af860f864966d76 (patch)
tree9a140acfda0e3f0755f2c120870c72d5a8f4b232
parent8cdfce564546c03c2c1ce179561591276aeb23a8 (diff)
IR: handle control flow constructs (#186)
* IR: handle control flow constructs This change includes a bunch of fixes and additions to the IR path: - `slang-ir-assembly` is now a valid output target (so we can use it for testing) - This uses what used to be the IR "dumping" logic, revamped to support much prettier output. - A future change will need to add back support for less prettified output to use when actually debugging - IR generation for `for` loops and `if` statements is supported - HLSL output from the above control flow constructs is implemented - Revamped the handling of l-values, and in particular work on compound ops like `+=` - Add basic IR support for `groupshared` variables - Add basic IR support for storing compute thread-group size - Output semantics on entry point parameters - This uses the AST structures to find semantics, so its still needs work - Pass through loop unroll flags - This is required to match `fxc` output, at least until we implement unrolling ourselves. * Fixup: 64-bit build issues. * fixup for merge
-rw-r--r--slang.h2
-rw-r--r--source/core/slang-string.cpp14
-rw-r--r--source/core/slang-string.h7
-rw-r--r--source/slang/compiler.cpp52
-rw-r--r--source/slang/compiler.h2
-rw-r--r--source/slang/emit.cpp438
-rw-r--r--source/slang/hlsl.meta.slang8
-rw-r--r--source/slang/hlsl.meta.slang.cpp9
-rw-r--r--source/slang/ir-inst-defs.h82
-rw-r--r--source/slang/ir-insts.h569
-rw-r--r--source/slang/ir.cpp644
-rw-r--r--source/slang/ir.h326
-rw-r--r--source/slang/lower-to-ir.cpp1211
-rw-r--r--source/slang/options.cpp2
-rw-r--r--source/slang/profile-defs.h2
-rw-r--r--source/slang/slang.vcxproj1
-rw-r--r--source/slang/slang.vcxproj.filters7
-rw-r--r--source/slang/type-layout.cpp2
-rw-r--r--tests/ir/loop.slang33
-rw-r--r--tests/ir/loop.slang.expected67
20 files changed, 2869 insertions, 609 deletions
diff --git a/slang.h b/slang.h
index e9b817dbb..161e703f5 100644
--- a/slang.h
+++ b/slang.h
@@ -76,6 +76,8 @@ extern "C"
SLANG_DXBC,
SLANG_DXBC_ASM,
SLANG_REFLECTION_JSON,
+ SLANG_IR,
+ SLANG_IR_ASM,
};
typedef int SlangPassThrough;
diff --git a/source/core/slang-string.cpp b/source/core/slang-string.cpp
index 9bc9e3a54..5a3d8e4f9 100644
--- a/source/core/slang-string.cpp
+++ b/source/core/slang-string.cpp
@@ -266,7 +266,7 @@ namespace Slang
append(slice.begin(), slice.end());
}
- void String::append(int value, int radix)
+ void String::append(int32_t value, int radix)
{
enum { kCount = 33 };
char* data = prepareForAppend(kCount);
@@ -275,7 +275,7 @@ namespace Slang
buffer->length += count;
}
- void String::append(unsigned int value, int radix)
+ void String::append(uint32_t value, int radix)
{
enum { kCount = 33 };
char* data = prepareForAppend(kCount);
@@ -284,7 +284,7 @@ namespace Slang
buffer->length += count;
}
- void String::append(long long value, int radix)
+ void String::append(int64_t value, int radix)
{
enum { kCount = 65 };
char* data = prepareForAppend(kCount);
@@ -293,6 +293,14 @@ namespace Slang
buffer->length += count;
}
+ void String::append(uint64_t value, int radix)
+ {
+ enum { kCount = 65 };
+ char* data = prepareForAppend(kCount);
+ auto count = IntToAscii(data, value, radix);
+ ReverseInternalAscii(data, count);
+ buffer->length += count;
+ }
void String::append(float val, const char * format)
{
diff --git a/source/core/slang-string.h b/source/core/slang-string.h
index 98fd51a07..ed333f8e8 100644
--- a/source/core/slang-string.h
+++ b/source/core/slang-string.h
@@ -257,9 +257,10 @@ namespace Slang
return getData() + getLength();
}
- void append(int value, int radix = 10);
- void append(unsigned int value, int radix = 10);
- void append(long long value, int radix = 10);
+ void append(int32_t value, int radix = 10);
+ void append(uint32_t value, int radix = 10);
+ void append(int64_t value, int radix = 10);
+ void append(uint64_t value, int radix = 10);
void append(float val, const char * format = "%g");
void append(double val, const char * format = "%g");
diff --git a/source/slang/compiler.cpp b/source/slang/compiler.cpp
index 9ca1d247f..25ae460c7 100644
--- a/source/slang/compiler.cpp
+++ b/source/slang/compiler.cpp
@@ -499,25 +499,14 @@ namespace Slang
}
#endif
-#if 0
- String emitSPIRVAssembly(
- ExtraContext& context)
+ List<uint8_t> emitSlangIRForEntryPoint(
+ EntryPointRequest* entryPoint)
{
- if(context.getTranslationUnitOptions().entryPoints.Count() == 0)
- {
- // TODO(tfoley): need to write diagnostics into this whole thing...
- fprintf(stderr, "no entry point specified\n");
- return "";
- }
-
- StringBuilder sb;
- for (auto entryPoint : context.getTranslationUnitOptions().entryPoints)
- {
- sb << emitSPIRVAssemblyForEntryPoint(context, entryPoint);
- }
- return sb.ProduceString();
+ SLANG_UNIMPLEMENTED_X("Slang IR Binary Generation");
}
-#endif
+
+ String emitSlangIRAssemblyForEntryPoint(
+ EntryPointRequest* entryPoint);
// Do emit logic for a single entry point
CompileResult emitEntryPoint(
@@ -578,6 +567,22 @@ namespace Slang
}
break;
+ case CodeGenTarget::SlangIR:
+ {
+ List<uint8_t> code = emitSlangIRForEntryPoint(entryPoint);
+ maybeDumpIntermediate(compileRequest, code.Buffer(), code.Count(), target);
+ result = CompileResult(code);
+ }
+ break;
+
+ case CodeGenTarget::SlangIRAssembly:
+ {
+ String code = emitSlangIRAssemblyForEntryPoint(entryPoint);
+ maybeDumpIntermediate(compileRequest, code.Buffer(), target);
+ result = CompileResult(code);
+ }
+ break;
+
case CodeGenTarget::None:
// The user requested no output
break;
@@ -957,6 +962,10 @@ namespace Slang
dumpIntermediateText(compileRequest, data, size, ".dxbc.asm");
break;
+ case CodeGenTarget::SlangIRAssembly:
+ dumpIntermediateText(compileRequest, data, size, ".slang-ir.asm");
+ break;
+
case CodeGenTarget::SPIRV:
dumpIntermediateBinary(compileRequest, data, size, ".spv");
{
@@ -972,6 +981,15 @@ namespace Slang
dumpIntermediateText(compileRequest, dxbcAssembly.begin(), dxbcAssembly.Length(), ".dxbc.asm");
}
break;
+
+ case CodeGenTarget::SlangIR:
+ dumpIntermediateBinary(compileRequest, data, size, ".slang-ir");
+ {
+ // TODO: need to support dissassembly from Slang IR binary
+// String slangIRAssembly = dissassembleSlangIR(compileRequest, data, size);
+// dumpIntermediateText(compileRequest, slangIRAssembly.begin(), slangIRAssembly.Length(), ".slang-ir.asm");
+ }
+ break;
}
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index b6eca1b36..7ec812d63 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -47,6 +47,8 @@ namespace Slang
DXBytecode = SLANG_DXBC,
DXBytecodeAssembly = SLANG_DXBC_ASM,
ReflectionJSON = SLANG_REFLECTION_JSON,
+ SlangIR = SLANG_IR,
+ SlangIRAssembly = SLANG_IR_ASM,
};
enum class LineDirectiveMode : SlangLineDirectiveMode
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index 209436da4..25a0b5323 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -1,6 +1,7 @@
// emit.cpp
#include "emit.h"
+#include "ir-insts.h"
#include "lower.h"
#include "lower-to-ir.h"
#include "name.h"
@@ -3966,11 +3967,16 @@ emitDeclImpl(decl, nullptr);
{
Simple,
Ptr,
+ Array,
};
Flavor flavor;
IRDeclaratorInfo* next;
- String const* name;
+ union
+ {
+ String const* name;
+ IRInst* elementCount;
+ };
};
void emitDeclarator(
@@ -3991,6 +3997,13 @@ emitDeclImpl(decl, nullptr);
emit("*");
emitDeclarator(context, declarator->next);
break;
+
+ case IRDeclaratorInfo::Flavor::Array:
+ emitDeclarator(context, declarator->next);
+ emit("[");
+ emitIROperand(context, declarator->elementCount);
+ emit("]");
+ break;
}
}
@@ -4008,6 +4021,13 @@ emitDeclImpl(decl, nullptr);
emit(((IRConstant*) inst)->u.floatVal);
break;
+ case kIROp_boolConst:
+ {
+ bool val = ((IRConstant*)inst)->u.intVal != 0;
+ emit(val ? "true" : "false");
+ }
+ break;
+
default:
SLANG_UNIMPLEMENTED_X("val case for emit");
break;
@@ -4028,6 +4048,8 @@ emitDeclImpl(decl, nullptr);
CASE(Float32Type, float);
CASE(Int32Type, int);
CASE(UInt32Type, uint);
+ CASE(VoidType, void);
+ CASE(BoolType, bool);
#undef CASE
@@ -4084,6 +4106,15 @@ emitDeclImpl(decl, nullptr);
}
break;
+ case kIROp_structuredBufferType:
+ {
+ auto tt = (IRBufferType*) type;
+ emit("StructuredBuffer<");
+ emitIRType(context, tt->getElementType(), nullptr);
+ emit(">");
+ }
+ break;
+
case kIROp_SamplerType:
{
// TODO: actually look at the flavor and emit the right name
@@ -4127,11 +4158,11 @@ emitDeclImpl(decl, nullptr);
IRType* type,
IRDeclaratorInfo* declarator)
{
- switch( type->op )
+ switch (type->op)
{
case kIROp_PtrType:
{
- auto ptrType = (IRPtrType*) type;
+ auto ptrType = (IRPtrType*)type;
IRDeclaratorInfo ptrDeclarator;
ptrDeclarator.flavor = IRDeclaratorInfo::Flavor::Ptr;
@@ -4140,6 +4171,18 @@ emitDeclImpl(decl, nullptr);
}
break;
+ case kIROp_arrayType:
+ {
+ auto arrayType = (IRArrayType*)type;
+
+ IRDeclaratorInfo arrayDeclarator;
+ arrayDeclarator.flavor = IRDeclaratorInfo::Flavor::Array;
+ arrayDeclarator.elementCount = arrayType->getElementCount();
+ arrayDeclarator.next = declarator;
+ emitIRType(context, arrayType->getElementType(), &arrayDeclarator);
+ }
+ break;
+
default:
emitIRSimpleType(context, type);
emitDeclarator(context, declarator);
@@ -4178,6 +4221,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_IntLit:
case kIROp_FloatLit:
+ case kIROp_boolConst:
case kIROp_FieldAddress:
case kIROp_getElementPtr:
return true;
@@ -4280,17 +4324,9 @@ emitDeclImpl(decl, nullptr);
switch(inst->op)
{
case kIROp_IntLit:
- {
- auto irConst = (IRConstant*) inst;
- emit(irConst->u.intVal);
- }
- break;
-
case kIROp_FloatLit:
- {
- auto irConst = (IRConstant*) inst;
- emit(irConst->u.floatVal);
- }
+ case kIROp_boolConst:
+ emitIRSimpleValue(context, inst);
break;
case kIROp_Construct:
@@ -4351,8 +4387,39 @@ emitDeclImpl(decl, nullptr);
CASE(kIROp_Div, /);
CASE(kIROp_Mod, %);
+ CASE(kIROp_Lsh, <<);
+ CASE(kIROp_Rsh, >>);
+
+ CASE(kIROp_Eql, ==);
+ CASE(kIROp_Neq, !=);
+ CASE(kIROp_Greater, >);
+ CASE(kIROp_Less, <);
+ CASE(kIROp_Geq, >=);
+ CASE(kIROp_Leq, <=);
+
+ CASE(kIROp_BitAnd, &);
+ CASE(kIROp_BitXor, ^);
+ CASE(kIROp_BitOr, |);
+
+ CASE(kIROp_And, &&);
+ CASE(kIROp_Or, ||);
+
#undef CASE
+ case kIROp_Not:
+ {
+ if (inst->getType()->op == kIROp_BoolType)
+ {
+ emit("!");
+ }
+ else
+ {
+ emit("~");
+ }
+ emitIROperand(context, inst->getArg(1));
+ }
+ break;
+
case kIROp_Sample:
emitIROperand(context, inst->getArg(1));
emit(".Sample(");
@@ -4408,6 +4475,18 @@ emitDeclImpl(decl, nullptr);
emit("]");
break;
+ case kIROp_BufferStore:
+ emitIROperand(context, inst->getArg(1));
+ emit("[");
+ emitIROperand(context, inst->getArg(2));
+ emit("] = ");
+ emitIROperand(context, inst->getArg(3));
+ break;
+
+ case kIROp_GroupMemoryBarrierWithGroupSync:
+ emit("GroupMemoryBarrierWithGroupSync()");
+ break;
+
case kIROp_getElement:
case kIROp_getElementPtr:
emitIROperand(context, inst->getArg(1));
@@ -4426,8 +4505,29 @@ emitDeclImpl(decl, nullptr);
emit(")");
break;
+ case kIROp_swizzle:
+ {
+ auto ii = (IRSwizzle*)inst;
+ emitIROperand(context, ii->getBase());
+ emit(".");
+ UInt elementCount = ii->getElementCount();
+ for (UInt ee = 0; ee < elementCount; ++ee)
+ {
+ IRInst* irElementIndex = ii->getElementIndex(ee);
+ assert(irElementIndex->op == kIROp_IntLit);
+ IRConstant* irConst = (IRConstant*)irElementIndex;
+
+ UInt elementIndex = (UInt)irConst->u.intVal;
+ assert(elementIndex < 4);
+
+ char const* kComponents[] = { "x", "y", "z", "w" };
+ emit(kComponents[elementIndex]);
+ }
+ }
+ break;
+
default:
- emit("/* uhandled */");
+ emit("/* unhandled */");
break;
}
}
@@ -4478,6 +4578,51 @@ emitDeclImpl(decl, nullptr);
emitIROperand(context, ((IRReturnVal*) inst)->getVal());
emit(";\n");
break;
+
+ case kIROp_swizzleSet:
+ {
+ auto ii = (IRSwizzleSet*)inst;
+ emitIRInstResultDecl(context, inst);
+ emitIROperand(context, inst->getArg(1));
+ emit(";\n");
+ emitIROperand(context, inst);
+ emit(".");
+ UInt elementCount = ii->getElementCount();
+ for (UInt ee = 0; ee < elementCount; ++ee)
+ {
+ IRInst* irElementIndex = ii->getElementIndex(ee);
+ assert(irElementIndex->op == kIROp_IntLit);
+ IRConstant* irConst = (IRConstant*)irElementIndex;
+
+ UInt elementIndex = (UInt)irConst->u.intVal;
+ assert(elementIndex < 4);
+
+ char const* kComponents[] = { "x", "y", "z", "w" };
+ emit(kComponents[elementIndex]);
+ }
+ emit(" = ");
+ emitIROperand(context, inst->getArg(2));
+ emit(";\n");
+ }
+ break;
+
+#if 0
+ case kIROp_unconditionalBranch:
+ emit("// unconditionalBranch ");
+ emitIROperand(context, inst->getArg(1));
+ emit("\n");
+ break;
+
+ case kIROp_conditionalBranch:
+ emit("// conditionalBranch ");
+ emitIROperand(context, inst->getArg(1));
+ emit(" ");
+ emitIROperand(context, inst->getArg(2));
+ emit(" ");
+ emitIROperand(context, inst->getArg(3));
+ emit("\n");
+ break;
+#endif
}
}
@@ -4515,6 +4660,238 @@ emitDeclImpl(decl, nullptr);
}
}
+ // We want to emit a range of code in the IR, represented
+ // by the blocks that are logically in the interval [begin, end)
+ // which we consider as a single-entry multiple-exit region.
+ //
+ // Note: because there are multiple exists, control flow
+ // may exit this region with operations that do *not* branch
+ // to `end`, but such non-local control flow will hopefully
+ // be captured.
+ //
+ void emitIRStmtsForBlocks(
+ EmitContext* context,
+ IRBlock* begin,
+ IRBlock* end)
+ {
+ IRBlock* block = begin;
+ while(block != end)
+ {
+ // Start by emitting the non-terminator instructions in the block.
+ auto terminator = block->lastChild;
+ assert(isTerminatorInst(terminator));
+ for (auto inst = block->firstChild; inst != terminator; inst = inst->nextInst)
+ {
+ emitIRInst(context, inst);
+ }
+
+ // Now look at the terminator instruction, which will tell us what we need to emit next.
+
+ switch (terminator->op)
+ {
+ default:
+ SLANG_UNEXPECTED("terminator inst");
+ return;
+
+ case kIROp_ReturnVal:
+ case kIROp_ReturnVoid:
+ emitIRInst(context, terminator);
+ return;
+
+ case kIROp_if:
+ {
+ // One-sided `if` statement
+ auto t = (IRIf*)terminator;
+
+ auto trueBlock = t->getTrueBlock();
+ auto afterBlock = t->getAfterBlock();
+
+ emit("if(");
+ emitIROperand(context, t->getCondition());
+ emit(")\n{\n");
+ emitIRStmtsForBlocks(
+ context,
+ trueBlock,
+ afterBlock);
+ emit("}\n");
+
+ // Continue with the block after the `if`
+ block = afterBlock;
+ }
+ break;
+
+ case kIROp_ifElse:
+ {
+ // Two-sided `if` statement
+ auto t = (IRIfElse*)terminator;
+
+ auto trueBlock = t->getTrueBlock();
+ auto falseBlock = t->getFalseBlock();
+ auto afterBlock = t->getAfterBlock();
+
+ emit("if(");
+ emitIROperand(context, t->getCondition());
+ emit(")\n{\n");
+ emitIRStmtsForBlocks(
+ context,
+ trueBlock,
+ afterBlock);
+ emit("}\nelse\n{\n");
+ emitIRStmtsForBlocks(
+ context,
+ falseBlock,
+ afterBlock);
+ emit("}\n");
+
+ // Continue with the block after the `if`
+ block = afterBlock;
+ }
+ break;
+
+ case kIROp_loop:
+ {
+ // Header for a `while` or `for` loop
+ auto t = (IRLoop*)terminator;
+
+ auto targetBlock = t->getTargetBlock();
+ auto breakBlock = t->getBreakBlock();
+ auto continueBlock = t->getContinueBlock();
+
+ if (auto loopControlDecoration = t->findDecoration<IRLoopControlDecoration>())
+ {
+ switch (loopControlDecoration->mode)
+ {
+ case kIRLoopControl_Unroll:
+ emit("[unroll]\n");
+ break;
+
+ default:
+ break;
+ }
+ }
+
+ // The challenging case for a loop is when
+ // there is a `continue` block that we
+ // need to deal with.
+ //
+ if (continueBlock == targetBlock)
+ {
+ // There is no continue block, so
+ // we only need to emit an endless
+ // loop and then manually `break`
+ // out of it in the right place(s)
+ emit("for(;;)\n{\n");
+
+ emitIRStmtsForBlocks(
+ context,
+ targetBlock,
+ nullptr);
+
+ emit("}\n");
+ }
+ else
+ {
+ // Okay, we've got a `continue` block,
+ // which means we really want to emit
+ // something akin to:
+ //
+ // for(;; <continueBlock>) { <bodyBlock> }
+ //
+ // In principle this isn't so bad, since the
+ // first case is just interVal [`continueBlock`, `targetBlock`)
+ // and the latter is the interval [`targetBlock`, `continueBlock`).
+ //
+ // The challenge of course is that a `for` statement
+ // only supports *expressions* in the continue part,
+ // and we might have expanded things into multiple
+ // instructions (especially if we inlined or desugared anything).
+ //
+ // There are a variety of ways we can support lowering this,
+ // but for now we are going to do something expedient
+ // that mimics what `fxc` seems to do:
+ //
+ // - Output loop body as `for(;;) { <bodyBlock> <continueBlock> }`
+ // - At any `continue` site, output `{ <continueBlock>; continue; }`
+ //
+ // This isn't ideal because it leads to code duplication, but
+ // it matches what `fxc` does so hopefully it will be the
+ // best option for our tests.
+ //
+
+ emit("for(;;)\n{\n");
+
+ // TODO: Okay, we *said* we'd do this special
+ // handling of the `continue` sites, but
+ // we aren't actually setting anything up here...
+ //
+
+ emitIRStmtsForBlocks(
+ context,
+ targetBlock,
+ nullptr);
+
+ emit("}\n");
+
+ }
+
+ // Continue with the block after the loop
+ block = breakBlock;
+ }
+ break;
+
+ case kIROp_break:
+ emit("break;\n");
+ return;
+
+ case kIROp_continue:
+ emit("continue;\n");
+ return;
+
+ case kIROp_loopTest:
+ {
+ // Loop condition being tested
+ auto t = (IRLoopTest*)terminator;
+
+ auto afterBlock = t->getTrueBlock();
+
+ emit("if(");
+ emitIROperand(context, t->getCondition());
+ emit(")\n{} else break;\n");
+
+ // Continue with the block after the test
+ block = afterBlock;
+ }
+ break;
+
+ case kIROp_unconditionalBranch:
+ {
+ // Unconditional branch as part of normal
+ // control flow. This is either a forward
+ // edge to the "next" block in an ordinary
+ // block, or a backward edge to the top
+ // of a loop.
+ auto t = (IRUnconditionalBranch*)terminator;
+ block = t->getTargetBlock();
+ }
+ break;
+
+ case kIROp_conditionalBranch:
+ SLANG_UNEXPECTED("terminator inst");
+ return;
+ }
+
+ // If we reach this point, then we've emitted
+ // one block, and we have a new block where
+ // control flow continues.
+ //
+ // We need to handle a special case here,
+ // when control flow jumps back to the
+ // starting block of the range we were
+ // asked to work with:
+ if (block == begin) return;
+ }
+ }
+
void emitIRFunc(
EmitContext* context,
IRFunc* func)
@@ -4522,6 +4899,19 @@ emitDeclImpl(decl, nullptr);
auto funcType = func->getType();
auto resultType = func->getResultType();
+ // Deal with decorations that need
+ // to be emitted as attributes
+ if (auto threadGroupSizeDecoration = func->findDecoration<IRComputeThreadGroupSizeDecoration>())
+ {
+ emit("[numthreads(");
+ for (int ii = 0; ii < 3; ++ii)
+ {
+ if (ii != 0) emit(", ");
+ Emit(threadGroupSizeDecoration->sizeAlongAxis[ii]);
+ }
+ emit(")]\n");
+ }
+
auto name = getName(func);
emitIRType(context, resultType, name);
@@ -4535,6 +4925,8 @@ emitDeclImpl(decl, nullptr);
auto paramName = getName(pp);
emitIRType(context, pp->getType(), paramName);
+
+ emitIRSemantics(context, pp);
}
emit(")");
@@ -4548,6 +4940,10 @@ emitDeclImpl(decl, nullptr);
emit("\n{\n");
// Need to emit the operations in the blocks of the function
+
+ emitIRStmtsForBlocks(context, func->getFirstBlock(), nullptr);
+
+#if 0
for( auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock() )
{
// TODO: need to handle control flow and so forth...
@@ -4556,6 +4952,7 @@ emitDeclImpl(decl, nullptr);
emitIRInst(context, ii);
}
}
+#endif
emit("}\n");
}
@@ -4687,7 +5084,8 @@ emitDeclImpl(decl, nullptr);
IRVar* varDecl)
{
auto allocatedType = varDecl->getType();
- auto varType = ((IRPtrType*) allocatedType)->getValueType();
+ auto varType = allocatedType->getValueType();
+ auto addressSpace = allocatedType->getAddressSpace();
switch( varType->op )
{
@@ -4706,6 +5104,16 @@ emitDeclImpl(decl, nullptr);
emitIRVarModifiers(context, layout);
+ switch (addressSpace)
+ {
+ default:
+ break;
+
+ case kIRAddressSpace_GroupShared:
+ emit("groupshared ");
+ break;
+ }
+
emitIRType(context, varType, getName(varDecl));
emitIRSemantics(context, varDecl);
diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang
index 7c7d3fda0..0ca838ff5 100644
--- a/source/slang/hlsl.meta.slang
+++ b/source/slang/hlsl.meta.slang
@@ -29,7 +29,13 @@ __magic_type(HLSLByteAddressBufferType) struct ByteAddressBuffer
__intrinsic_op uint4 Load4(int location, out uint status);
};
-__generic<T> __magic_type(HLSLStructuredBufferType) struct StructuredBuffer
+__generic<T>
+__magic_type(HLSLStructuredBufferType)
+__intrinsic_type(${{
+ // TODO: we really need a simple way to write an "expression splice"
+ sb << kIROp_structuredBufferType;
+}})
+struct StructuredBuffer
{
__intrinsic_op void GetDimensions(
out uint numStructs,
diff --git a/source/slang/hlsl.meta.slang.cpp b/source/slang/hlsl.meta.slang.cpp
index a3744f620..49254ac60 100644
--- a/source/slang/hlsl.meta.slang.cpp
+++ b/source/slang/hlsl.meta.slang.cpp
@@ -29,7 +29,14 @@ sb << " __intrinsic_op uint4 Load4(int location);\n";
sb << " __intrinsic_op uint4 Load4(int location, out uint status);\n";
sb << "};\n";
sb << "\n";
-sb << "__generic<T> __magic_type(HLSLStructuredBufferType) struct StructuredBuffer\n";
+sb << "__generic<T>\n";
+sb << "__magic_type(HLSLStructuredBufferType)\n";
+sb << "__intrinsic_type(";
+
+ // TODO: we really need a simple way to write an "expression splice"
+ sb << kIROp_structuredBufferType;
+sb << ")\n";
+sb << "struct StructuredBuffer\n";
sb << "{\n";
sb << " __intrinsic_op void GetDimensions(\n";
sb << " out uint numStructs,\n";
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index dab7aa3d7..deaf02d56 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -13,24 +13,29 @@
// Invalid operation: should not appear in valid code
INST(Nop, nop, 0, 0)
-INST(TypeType, type.type, 0, 0)
-INST(VoidType, type.void, 0, 0)
-INST(BlockType, type.block, 0, 0)
-INST(VectorType, type.vector, 2, 0)
-INST(MatrixType, matrixType, 3, 0)
-INST(BoolType, type.bool, 0, 0)
-INST(Float32Type, type.f32, 0, 0)
-INST(Int32Type, type.i32, 0, 0)
-INST(UInt32Type, type.u32, 0, 0)
-INST(StructType, type.struct, 0, PARENT)
-INST(FuncType, func_type, 0, 0)
-INST(PtrType, ptr_type, 1, 0)
-INST(TextureType, texture_type, 2, 0)
-INST(SamplerType, sampler_type, 1, 0)
-INST(ConstantBufferType, constant_buffer_type, 1, 0)
-INST(TextureBufferType, texture_buffer_type, 1, 0)
-INST(readWriteStructuredBufferType, readWriteStructuredBufferType, 1, 0)
-
+INST(TypeType, Type, 0, 0)
+INST(VoidType, Void, 0, 0)
+INST(BlockType, Block, 0, 0)
+INST(VectorType, Vec, 2, 0)
+INST(MatrixType, Mat, 3, 0)
+INST(arrayType, Array, 2, 0)
+
+INST(BoolType, Bool, 0, 0)
+INST(Float32Type, Float32, 0, 0)
+INST(Int32Type, Int32, 0, 0)
+INST(UInt32Type, UInt32, 0, 0)
+INST(StructType, Struct, 0, PARENT)
+INST(FuncType, Func, 0, 0)
+INST(PtrType, Ptr, 1, 0)
+INST(TextureType, Texture, 2, 0)
+INST(SamplerType, SamplerState, 1, 0)
+INST(ConstantBufferType, ConstantBuffer, 1, 0)
+INST(TextureBufferType, TextureBuffer, 1, 0)
+
+INST(structuredBufferType, StructuredBuffer, 1, 0)
+INST(readWriteStructuredBufferType, RWStructuredBuffer, 1, 0)
+
+INST(boolConst, boolConst, 0, 0)
INST(IntLit, integer_constant, 0, 0)
INST(FloatLit, float_constant, 0, 0)
@@ -57,9 +62,48 @@ INST(FieldAddress, get_field_addr, 2, 0)
INST(getElement, getElement, 2, 0)
INST(getElementPtr, getElementPtr, 2, 0)
+// A swizzle of a vector:
+//
+// %dst = swizzle %src %idx0 %idx1 ...
+//
+// where:
+// - `src` is a vector<T,N>
+// - `dst` is a vector<T,M>
+// - `idx0` through `idx[M-1]` are literal integers
+//
+INST(swizzle, swizzle, 1, 0)
+
+// Setting a vector via swizzle
+//
+// %dst = swizzle %base %src %idx0 %idx1 ...
+//
+// where:
+// - `base` is a vector<T,N>
+// - `dst` is a vector<T,N>
+// - `src` is a vector<T,M>
+// - `idx0` through `idx[M-1]` are literal integers
+//
+// The semantics of the op is:
+//
+// dst = base;
+// for(ii : 0 ... M-1 )
+// dst[ii] = src[idx[ii]];
+//
+INST(swizzleSet, swizzleSet, 2, 0)
+
+
INST(ReturnVal, return_val, 1, 0)
INST(ReturnVoid, return_void, 1, 0)
+INST(unconditionalBranch, unconditionalBranch, 1, 0)
+INST(break, break, 1, 0)
+INST(continue, continue, 1, 0)
+INST(loop, loop, 3, 0)
+
+INST(conditionalBranch, conditionalBranch, 1, 0)
+INST(if, if, 3, 0)
+INST(ifElse, ifElse, 4, 0)
+INST(loopTest, loopTest, 3, 0)
INST(Add, add, 2, 0)
INST(Sub, sub, 2, 0)
@@ -139,6 +183,8 @@ INST(Sample, sample, 3, 0)
INST(SampleGrad, sampleGrad, 4, 0)
+INST(GroupMemoryBarrierWithGroupSync, GroupMemoryBarrierWithGroupSync, 0, 0)
+
PSEUDO_INST(Pos)
PSEUDO_INST(PreInc)
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
new file mode 100644
index 000000000..846e87bb6
--- /dev/null
+++ b/source/slang/ir-insts.h
@@ -0,0 +1,569 @@
+// ir-insts.h
+#ifndef SLANG_IR_INSTS_H_INCLUDED
+#define SLANG_IR_INSTS_H_INCLUDED
+
+// This file extends the core definitions in `ir.h`
+// with a wider variety of concrete instructions,
+// and a "builder" abstraction.
+//
+// TODO: the builder probably needs its own file.
+
+#include "compiler.h"
+#include "ir.h"
+
+namespace Slang {
+
+class Decl;
+
+// Associates an IR-level decoration with a source declaration
+// in the high-level AST, that can be used to extract
+// additional information that informs code emission.
+struct IRHighLevelDeclDecoration : IRDecoration
+{
+ enum { kDecorationOp = kIRDecorationOp_HighLevelDecl };
+
+ Decl* decl;
+};
+
+// Associates an IR-level decoration with a source layout
+struct IRLayoutDecoration : IRDecoration
+{
+ enum { kDecorationOp = kIRDecorationOp_Layout };
+
+ Layout* layout;
+};
+
+// Identifies a function as an entry point for some stage
+struct IREntryPointDecoration : IRDecoration
+{
+ enum { kDecorationOp = kIRDecorationOp_EntryPoint };
+
+ Profile profile;
+};
+
+// Associates a compute-shader entry point function
+// with a thread-group size.
+struct IRComputeThreadGroupSizeDecoration : IRDecoration
+{
+ enum { kDecorationOp = kIRDecorationOp_ComputeThreadGroupSize };
+
+ UInt sizeAlongAxis[3];
+};
+
+enum IRLoopControl
+{
+ kIRLoopControl_Unroll,
+};
+
+struct IRLoopControlDecoration : IRDecoration
+{
+ enum { kDecorationOp = kIRDecorationOp_LoopControl };
+
+ IRLoopControl mode;
+};
+
+// Types
+
+struct IRVectorType : IRType
+{
+ IRUse elementType;
+ IRUse elementCount;
+
+ IRType* getElementType() { return (IRType*) elementType.usedValue; }
+ IRInst* getElementCount() { return elementCount.usedValue; }
+};
+
+struct IRMatrixType : IRType
+{
+ IRUse elementType;
+ IRUse rowCount;
+ IRUse columnCount;
+
+ IRType* getElementType() { return (IRType*) elementType.usedValue; }
+ IRInst* getRowCount() { return rowCount.usedValue; }
+ IRInst* getColumnCount() { return columnCount.usedValue; }
+};
+
+struct IRArrayType : IRType
+{
+ IRUse elementType;
+ IRUse elementCount;
+
+ IRType* getElementType() { return (IRType*) elementType.usedValue; }
+ IRInst* getElementCount() { return elementCount.usedValue; }
+};
+
+
+
+struct IRFuncType : IRType
+{
+ IRUse resultType;
+ // parameter tyeps are varargs...
+
+ IRType* getResultType() { return (IRType*) resultType.usedValue; }
+ UInt getParamCount()
+ {
+ return getArgCount() - 2;
+ }
+ IRType* getParamType(UInt index)
+ {
+ return (IRType*) getArg(2 + index);
+ }
+};
+
+// Address spaces for IR pointers
+enum IRAddressSpace : UInt
+{
+ // A default address space for things like local variables
+ kIRAddressSpace_Default,
+
+ // Address space for HLSL `groupshared` allocations
+ kIRAddressSpace_GroupShared,
+};
+
+struct IRPtrType : IRType
+{
+ IRUse valueType;
+ IRUse addressSpace;
+
+ IRType* getValueType() { return (IRType*) valueType.usedValue; }
+
+ IRAddressSpace getAddressSpace()
+ {
+ return IRAddressSpace(
+ ((IRConstant*)addressSpace.usedValue)->u.intVal);
+ }
+};
+
+struct IRTextureType : IRType
+{
+ IRUse flavor;
+ IRUse elementType;
+
+ IRIntegerValue getFlavor() { return ((IRConstant*) flavor.usedValue)->u.intVal; }
+ IRType* getElementType() { return (IRType*) elementType.usedValue; }
+};
+
+struct IRBufferType : IRType
+{
+ IRUse elementType;
+ IRType* getElementType() { return (IRType*) elementType.usedValue; }
+};
+
+
+struct IRUniformBufferType : IRType
+{
+ IRUse elementType;
+ IRType* getElementType() { return (IRType*) elementType.usedValue; }
+};
+
+struct IRConstantBufferType : IRUniformBufferType {};
+struct IRTextureBufferType : IRUniformBufferType {};
+
+struct IRCall : IRInst
+{
+ IRUse func;
+};
+
+struct IRLoad : IRInst
+{
+ IRUse ptr;
+};
+
+struct IRStore : IRInst
+{
+ IRUse ptr;
+ IRUse val;
+};
+
+struct IRStructField;
+struct IRFieldExtract : IRInst
+{
+ IRUse base;
+ IRUse field;
+
+ IRInst* getBase() { return base.usedValue; }
+ IRStructField* getField() { return (IRStructField*) field.usedValue; }
+};
+
+struct IRFieldAddress : IRInst
+{
+ IRUse base;
+ IRUse field;
+
+ IRInst* getBase() { return base.usedValue; }
+ IRStructField* getField() { return (IRStructField*) field.usedValue; }
+};
+
+// Terminators
+
+struct IRReturn : IRTerminatorInst
+{};
+
+struct IRReturnVal : IRReturn
+{
+ IRUse val;
+
+ IRInst* getVal() { return val.usedValue; }
+};
+
+struct IRReturnVoid : IRReturn
+{};
+
+struct IRBlock;
+
+struct IRUnconditionalBranch : IRTerminatorInst
+{
+ IRUse block;
+
+ IRBlock* getTargetBlock() { return (IRBlock*)block.usedValue; }
+};
+
+// Special cases of unconditional branch, to handle
+// structured control flow:
+struct IRBreak : IRUnconditionalBranch {};
+struct IRContinue : IRUnconditionalBranch {};
+
+// The start of a loop is a special control-flow
+// instruction, that records relevant information
+// about the loop structure:
+struct IRLoop : IRUnconditionalBranch
+{
+ // The next block after the loop, which
+ // is where we expect control flow to
+ // re-converge, and also where a
+ // `break` will target.
+ IRUse breakBlock;
+
+ // The block where control flow will go
+ // on a `continue`.
+ IRUse continueBlock;
+
+ IRBlock* getBreakBlock() { return (IRBlock*)breakBlock.usedValue; }
+ IRBlock* getContinueBlock() { return (IRBlock*)continueBlock.usedValue; }
+};
+
+struct IRConditionalBranch : IRTerminatorInst
+{
+ IRUse condition;
+ IRUse trueBlock;
+ IRUse falseBlock;
+
+ IRInst* getCondition() { return condition.usedValue; }
+ IRBlock* getTrueBlock() { return (IRBlock*)trueBlock.usedValue; }
+ IRBlock* getFalseBlock() { return (IRBlock*)falseBlock.usedValue; }
+};
+
+// A conditional branch that represent the test inside a loop
+struct IRLoopTest : IRConditionalBranch
+{
+};
+
+// A conditional branch that represents a one-sided `if`:
+//
+// if( <condition> ) { <trueBlock> }
+// <falseBlock>
+struct IRIf : IRConditionalBranch
+{
+ IRBlock* getAfterBlock() { return getFalseBlock(); }
+};
+
+// A conditional branch that represents a two-sided `if`:
+//
+// if( <condition> ) { <trueBlock> }
+// else { <falseBlock> }
+// <afterBlock>
+//
+struct IRIfElse : IRConditionalBranch
+{
+ IRUse afterBlock;
+
+ IRBlock* getAfterBlock() { return (IRBlock*)afterBlock.usedValue; }
+};
+
+struct IRSwizzle : IRReturn
+{
+ IRUse base;
+
+ IRInst* getBase() { return base.usedValue; }
+ UInt getElementCount()
+ {
+ return getArgCount() - 2;
+ }
+ IRInst* getElementIndex(UInt index)
+ {
+ return getArg(index + 2);
+ }
+};
+
+struct IRSwizzleSet : IRReturn
+{
+ IRUse base;
+ IRUse source;
+
+ IRInst* getBase() { return base.usedValue; }
+ IRInst* getSource() { return source.usedValue; }
+ UInt getElementCount()
+ {
+ return getArgCount() - 3;
+ }
+ IRInst* getElementIndex(UInt index)
+ {
+ return getArg(index + 3);
+ }
+};
+
+// "Parent" Instructions (Declarations)
+
+struct IRStructField : IRInst
+{
+ IRType* getFieldType() { return (IRType*) type.usedValue; }
+
+ IRStructField* getNextField() { return (IRStructField*) nextInst; }
+};
+
+struct IRStructDecl : IRParentInst
+{
+ IRStructField* getFirstField() { return (IRStructField*) firstChild; }
+ IRStructField* getLastField() { return (IRStructField*) lastChild; }
+};
+
+
+struct IRVar : IRInst
+{
+ IRPtrType* getType()
+ {
+ return (IRPtrType*)type.usedValue;
+ }
+};
+
+
+// Description of an instruction to be used for global value numbering
+struct IRInstKey
+{
+ IRInst* inst;
+
+ int GetHashCode();
+};
+
+bool operator==(IRInstKey const& left, IRInstKey const& right);
+
+struct IRConstantKey
+{
+ IRConstant* inst;
+
+ int GetHashCode();
+};
+bool operator==(IRConstantKey const& left, IRConstantKey const& right);
+
+struct SharedIRBuilder
+{
+ // The module that will own all of the IR
+ IRModule* module;
+
+ Dictionary<IRInstKey, IRInst*> globalValueNumberingMap;
+ Dictionary<IRConstantKey, IRConstant*> constantMap;
+};
+
+struct IRBuilder
+{
+ // Shared state for all IR builders working on the same module
+ SharedIRBuilder* shared;
+
+ IRModule* getModule() { return shared->module; }
+
+ // The parent instruction to add children to.
+ IRParentInst* parentInst;
+
+ void addInst(IRParentInst* parent, IRInst* inst);
+ void addInst(IRInst* inst);
+
+ IRType* getBaseType(BaseType flavor);
+ IRType* getBoolType();
+ IRType* getVectorType(IRType* elementType, IRValue* elementCount);
+ IRType* getMatrixType(
+ IRType* elementType,
+ IRValue* rowCount,
+ IRValue* columnCount);
+ IRType* getArrayType(IRType* elementType, IRValue* elementCount);
+ IRType* getArrayType(IRType* elementType);
+
+ IRType* getTypeType();
+ IRType* getVoidType();
+ IRType* getBlockType();
+
+ IRType* getIntrinsicType(
+ IROp op,
+ UInt argCount,
+ IRValue* const* args);
+
+ IRStructDecl* createStructType();
+ IRStructField* createStructField(IRType* fieldType);
+
+ IRType* getFuncType(
+ UInt paramCount,
+ IRType* const* paramTypes,
+ IRType* resultType);
+
+ IRType* getPtrType(
+ IRType* valueType,
+ IRAddressSpace addressSpace);
+
+ IRType* getPtrType(
+ IRType* valueType);
+
+ IRValue* getBoolValue(bool value);
+ IRValue* getIntValue(IRType* type, IRIntegerValue value);
+ IRValue* getFloatValue(IRType* type, IRFloatingPointValue value);
+
+ IRInst* emitCallInst(
+ IRType* type,
+ IRValue* func,
+ UInt argCount,
+ IRValue* const* args);
+
+ IRInst* emitIntrinsicInst(
+ IRType* type,
+ IROp op,
+ UInt argCount,
+ IRValue* const* args);
+
+ IRInst* emitConstructorInst(
+ IRType* type,
+ UInt argCount,
+ IRValue* const* args);
+
+ IRModule* createModule();
+
+ IRFunc* createFunc();
+
+ IRBlock* createBlock();
+ IRBlock* emitBlock();
+
+ IRParam* emitParam(
+ IRType* type);
+
+ IRVar* emitVar(
+ IRType* type,
+ IRAddressSpace addressSpace);
+
+ IRVar* emitVar(
+ IRType* type);
+
+ IRInst* emitLoad(
+ IRValue* ptr);
+
+ IRInst* emitStore(
+ IRValue* dstPtr,
+ IRValue* srcVal);
+
+ IRInst* emitFieldExtract(
+ IRType* type,
+ IRValue* base,
+ IRStructField* field);
+
+ IRInst* emitFieldAddress(
+ IRType* type,
+ IRValue* basePtr,
+ IRStructField* field);
+
+ IRInst* emitElementExtract(
+ IRType* type,
+ IRValue* base,
+ IRValue* index);
+
+ IRInst* emitElementAddress(
+ IRType* type,
+ IRValue* basePtr,
+ IRValue* index);
+
+ IRInst* emitSwizzle(
+ IRType* type,
+ IRValue* base,
+ UInt elementCount,
+ IRValue* const* elementIndices);
+
+ IRInst* emitSwizzle(
+ IRType* type,
+ IRValue* base,
+ UInt elementCount,
+ UInt const* elementIndices);
+
+ IRInst* emitSwizzleSet(
+ IRType* type,
+ IRValue* base,
+ IRValue* source,
+ UInt elementCount,
+ IRValue* const* elementIndices);
+
+ IRInst* emitSwizzleSet(
+ IRType* type,
+ IRValue* base,
+ IRValue* source,
+ UInt elementCount,
+ UInt const* elementIndices);
+
+ IRInst* emitReturn(
+ IRValue* val);
+
+ IRInst* emitReturn();
+
+ IRInst* emitBranch(
+ IRBlock* block);
+
+ IRInst* emitBreak(
+ IRBlock* target);
+
+ IRInst* emitContinue(
+ IRBlock* target);
+
+ IRInst* emitLoop(
+ IRBlock* target,
+ IRBlock* breakBlock,
+ IRBlock* continueBlock);
+
+ IRInst* emitBranch(
+ IRValue* val,
+ IRBlock* trueBlock,
+ IRBlock* falseBlock);
+
+ IRInst* emitIf(
+ IRValue* val,
+ IRBlock* trueBlock,
+ IRBlock* afterBlock);
+
+ IRInst* emitIfElse(
+ IRValue* val,
+ IRBlock* trueBlock,
+ IRBlock* falseBlock,
+ IRBlock* afterBlock);
+
+ IRInst* emitLoopTest(
+ IRValue* val,
+ IRBlock* bodyBlock,
+ IRBlock* breakBlock);
+
+ IRDecoration* addDecorationImpl(
+ IRInst* inst,
+ UInt decorationSize,
+ IRDecorationOp op);
+
+ template<typename T>
+ T* addDecoration(IRInst* inst, IRDecorationOp op)
+ {
+ return (T*) addDecorationImpl(inst, sizeof(T), op);
+ }
+
+ template<typename T>
+ T* addDecoration(IRInst* inst)
+ {
+ return (T*) addDecorationImpl(inst, sizeof(T), IRDecorationOp(T::kDecorationOp));
+ }
+
+ IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRInst* inst, Decl* decl);
+ IRLayoutDecoration* addLayoutDecoration(IRInst* inst, Layout* layout);
+};
+
+}
+
+#endif
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 3a6410125..c48880cab 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -1,5 +1,6 @@
// ir.cpp
#include "ir.h"
+#include "ir-insts.h"
#include "../core/basic.h"
@@ -65,7 +66,12 @@ namespace Slang
return nullptr;
}
- //
+ // IRFunc
+
+ IRType* IRFunc::getResultType() { return getType()->getResultType(); }
+ UInt IRFunc::getParamCount() { return getType()->getParamCount(); }
+ IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); }
+
IRParam* IRFunc::getFirstParam()
{
@@ -81,6 +87,8 @@ namespace Slang
return (IRParam*) firstInst;
}
+ // IRParam
+
IRParam* IRParam::getNextParam()
{
auto next = nextInst;
@@ -94,6 +102,36 @@ namespace Slang
//
+ bool isTerminatorInst(IROp op)
+ {
+ switch (op)
+ {
+ default:
+ return false;
+
+ case kIROp_ReturnVal:
+ case kIROp_ReturnVoid:
+ case kIROp_unconditionalBranch:
+ case kIROp_conditionalBranch:
+ case kIROp_break:
+ case kIROp_continue:
+ case kIROp_loop:
+ case kIROp_if:
+ case kIROp_ifElse:
+ case kIROp_loopTest:
+ return true;
+ }
+ }
+
+ bool isTerminatorInst(IRInst* inst)
+ {
+ if (!inst) return false;
+ return isTerminatorInst(inst->op);
+ }
+
+
+ //
+
// Add an instruction to a specific parent
void IRBuilder::addInst(IRParentInst* parent, IRInst* inst)
{
@@ -306,6 +344,28 @@ namespace Slang
varArgs);
}
+ template<typename T>
+ static T* createInstWithTrailingArgs(
+ IRBuilder* builder,
+ IROp op,
+ IRType* type,
+ IRValue* arg1,
+ UInt varArgCount,
+ IRValue* const* varArgs)
+ {
+ IRValue* fixedArgs[] = { arg1 };
+ UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]);
+
+ return (T*)createInstImpl(
+ builder,
+ sizeof(T) + varArgCount * sizeof(IRUse),
+ op,
+ type,
+ fixedArgCount,
+ fixedArgs,
+ varArgCount,
+ varArgs);
+ }
//
bool operator==(IRInstKey const& left, IRInstKey const& right)
@@ -637,6 +697,8 @@ namespace Slang
{
switch( flavor )
{
+ case BaseType::Void: return getVoidType();
+
case BaseType::Bool: return getBaseTypeImpl(this, kIROp_BoolType);
case BaseType::Float: return getBaseTypeImpl(this, kIROp_Float32Type);
case BaseType::Int: return getBaseTypeImpl(this, kIROp_Int32Type);
@@ -677,6 +739,34 @@ namespace Slang
columnCount);
}
+ IRType* IRBuilder::getArrayType(IRType* elementType, IRValue* elementCount)
+ {
+ // The client requests an unsized array by passing `nullptr` for
+ // the `elementCount`.
+ //
+ // We currently encode an unsized array as an ordinary array with
+ // zero elements. TODO: carefully consider this choice.
+ if (!elementCount)
+ {
+ elementCount = getIntValue(
+ getBaseType(BaseType::Int),
+ 0);
+ }
+
+ return findOrEmitInst<IRArrayType>(
+ this,
+ kIROp_arrayType,
+ getTypeType(),
+ elementType,
+ elementCount);
+ }
+
+ IRType* IRBuilder::getArrayType(IRType* elementType)
+ {
+ return getArrayType(elementType, nullptr);
+ }
+
+
IRType* IRBuilder::getTypeType()
{
return findOrEmitInst<IRType>(
@@ -753,21 +843,38 @@ namespace Slang
}
IRType* IRBuilder::getPtrType(
- IRType* valueType)
+ IRType* valueType,
+ IRAddressSpace addressSpace)
{
+ auto uintType = getBaseType(BaseType::UInt);
+ auto irAddressSpace = getIntValue(uintType, addressSpace);
+
auto inst = findOrEmitInst<IRPtrType>(
this,
kIROp_PtrType,
getTypeType(),
- 1,
- (IRValue* const*) &valueType);
+ valueType,
+ irAddressSpace);
return inst;
}
- IRValue* IRBuilder::getBoolValue(bool value)
+ IRType* IRBuilder::getPtrType(
+ IRType* valueType)
{
- SLANG_UNIMPLEMENTED_X("IR");
+ return getPtrType(valueType, kIRAddressSpace_Default);
+ }
+
+
+ IRValue* IRBuilder::getBoolValue(bool inValue)
+ {
+ IRIntegerValue value = inValue;
+ return findOrEmitConstant(
+ this,
+ kIROp_boolConst,
+ getBoolType(),
+ sizeof(value),
+ &value);
}
IRValue* IRBuilder::getIntValue(IRType* type, IRIntegerValue value)
@@ -883,9 +990,10 @@ namespace Slang
}
IRVar* IRBuilder::emitVar(
- IRType* type)
+ IRType* type,
+ IRAddressSpace addressSpace)
{
- auto allocatedType = getPtrType(type);
+ auto allocatedType = getPtrType(type, addressSpace);
auto inst = createInst<IRVar>(
this,
kIROp_Var,
@@ -894,6 +1002,13 @@ namespace Slang
return inst;
}
+
+ IRVar* IRBuilder::emitVar(
+ IRType* type)
+ {
+ return emitVar(type, kIRAddressSpace_Default);
+ }
+
IRInst* IRBuilder::emitLoad(
IRValue* ptr)
{
@@ -996,6 +1111,83 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitSwizzle(
+ IRType* type,
+ IRValue* base,
+ UInt elementCount,
+ IRValue* const* elementIndices)
+ {
+ auto inst = createInstWithTrailingArgs<IRSwizzle>(
+ this,
+ kIROp_swizzle,
+ type,
+ base,
+ elementCount,
+ elementIndices);
+
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitSwizzle(
+ IRType* type,
+ IRValue* base,
+ UInt elementCount,
+ UInt const* elementIndices)
+ {
+ auto intType = getBaseType(BaseType::Int);
+
+ IRValue* irElementIndices[4];
+ for (UInt ii = 0; ii < elementCount; ++ii)
+ {
+ irElementIndices[ii] = getIntValue(intType, elementIndices[ii]);
+ }
+
+ return emitSwizzle(type, base, elementCount, irElementIndices);
+ }
+
+
+ IRInst* IRBuilder::emitSwizzleSet(
+ IRType* type,
+ IRValue* base,
+ IRValue* source,
+ UInt elementCount,
+ IRValue* const* elementIndices)
+ {
+ IRValue* fixedArgs[] = { base, source };
+ UInt fixedArgCount = sizeof(fixedArgs) / sizeof(fixedArgs[0]);
+
+ auto inst = createInstWithTrailingArgs<IRSwizzleSet>(
+ this,
+ kIROp_swizzleSet,
+ type,
+ fixedArgCount,
+ fixedArgs,
+ elementCount,
+ elementIndices);
+
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitSwizzleSet(
+ IRType* type,
+ IRValue* base,
+ IRValue* source,
+ UInt elementCount,
+ UInt const* elementIndices)
+ {
+ auto intType = getBaseType(BaseType::Int);
+
+ IRValue* irElementIndices[4];
+ for (UInt ii = 0; ii < elementCount; ++ii)
+ {
+ irElementIndices[ii] = getIntValue(intType, elementIndices[ii]);
+ }
+
+ return emitSwizzleSet(type, base, source, elementCount, irElementIndices);
+ }
+
IRInst* IRBuilder::emitReturn(
IRValue* val)
{
@@ -1018,6 +1210,133 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitBranch(
+ IRBlock* block)
+ {
+ auto inst = createInst<IRUnconditionalBranch>(
+ this,
+ kIROp_unconditionalBranch,
+ getVoidType(),
+ block);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitBreak(
+ IRBlock* target)
+ {
+ auto inst = createInst<IRBreak>(
+ this,
+ kIROp_break,
+ getVoidType(),
+ target);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitContinue(
+ IRBlock* target)
+ {
+ auto inst = createInst<IRContinue>(
+ this,
+ kIROp_continue,
+ getVoidType(),
+ target);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitLoop(
+ IRBlock* target,
+ IRBlock* breakBlock,
+ IRBlock* continueBlock)
+ {
+ IRInst* args[] = { target, breakBlock, continueBlock };
+ UInt argCount = sizeof(args) / sizeof(args[0]);
+
+ auto inst = createInst<IRLoop>(
+ this,
+ kIROp_loop,
+ getVoidType(),
+ argCount,
+ args);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitBranch(
+ IRValue* val,
+ IRBlock* trueBlock,
+ IRBlock* falseBlock)
+ {
+ IRInst* args[] = { val, trueBlock, falseBlock };
+ UInt argCount = sizeof(args) / sizeof(args[0]);
+
+ auto inst = createInst<IRConditionalBranch>(
+ this,
+ kIROp_conditionalBranch,
+ getVoidType(),
+ argCount,
+ args);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitIf(
+ IRValue* val,
+ IRBlock* trueBlock,
+ IRBlock* afterBlock)
+ {
+ IRInst* args[] = { val, trueBlock, afterBlock };
+ UInt argCount = sizeof(args) / sizeof(args[0]);
+
+ auto inst = createInst<IRIf>(
+ this,
+ kIROp_if,
+ getVoidType(),
+ argCount,
+ args);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitIfElse(
+ IRValue* val,
+ IRBlock* trueBlock,
+ IRBlock* falseBlock,
+ IRBlock* afterBlock)
+ {
+ IRInst* args[] = { val, trueBlock, falseBlock, afterBlock };
+ UInt argCount = sizeof(args) / sizeof(args[0]);
+
+ auto inst = createInst<IRIfElse>(
+ this,
+ kIROp_ifElse,
+ getVoidType(),
+ argCount,
+ args);
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitLoopTest(
+ IRValue* val,
+ IRBlock* bodyBlock,
+ IRBlock* breakBlock)
+ {
+ IRInst* args[] = { val, bodyBlock, breakBlock };
+ UInt argCount = sizeof(args) / sizeof(args[0]);
+
+ auto inst = createInst<IRLoopTest>(
+ this,
+ kIROp_loopTest,
+ getVoidType(),
+ argCount,
+ args);
+ addInst(inst);
+ return inst;
+ }
+
IRDecoration* IRBuilder::addDecorationImpl(
IRInst* inst,
UInt decorationSize,
@@ -1053,7 +1372,7 @@ namespace Slang
struct IRDumpContext
{
- FILE* file;
+ StringBuilder* builder;
int indent;
};
@@ -1061,14 +1380,16 @@ namespace Slang
IRDumpContext* context,
char const* text)
{
- fprintf(context->file, "%s", text);
+ context->builder->append(text);
}
static void dump(
IRDumpContext* context,
UInt val)
{
- fprintf(context->file, "%llu", (unsigned long long)val);
+ context->builder->append(val);
+
+// fprintf(context->file, "%llu", (unsigned long long)val);
}
static void dumpIndent(
@@ -1076,7 +1397,7 @@ namespace Slang
{
for (int ii = 0; ii < context->indent; ++ii)
{
- dump(context, " ");
+ dump(context, "\t");
}
}
@@ -1088,79 +1409,318 @@ namespace Slang
{
dump(context, "<null>");
}
- else
+ else if(inst->id)
{
dump(context, "%");
dump(context, inst->id);
}
+ else
+ {
+ dump(context, "_");
+ }
}
- static void dumpInst(
+ static void dumpType(
+ IRDumpContext* context,
+ IRType* type);
+
+ static void dumpOperand(
IRDumpContext* context,
IRInst* inst)
{
- dumpIndent(context);
if (!inst)
{
- dump(context, "<null>");
+ dump(context, "undef");
+ return;
}
- // TODO: need to display a name for the result...
+ switch (inst->op)
+ {
+ case kIROp_IntLit:
+ dump(context, ((IRConstant*)inst)->u.intVal);
+ return;
- auto op = inst->op;
- auto opInfo = &kIROpInfos[op];
+ case kIROp_FloatLit:
+ dump(context, ((IRConstant*)inst)->u.floatVal);
+ return;
- if (inst->id)
+ case kIROp_boolConst:
+ dump(context, ((IRConstant*)inst)->u.intVal ? "true" : "false");
+ return;
+
+ case kIROp_TypeType:
+ dumpType(context, (IRType*)inst);
+ return;
+
+ default:
+ break;
+ }
+
+ auto type = inst->getType();
+ if (type)
{
- dumpID(context, inst);
- dump(context, " = ");
+ switch (type->op)
+ {
+ case kIROp_TypeType:
+ dumpType(context, (IRType*)inst);
+ return;
+
+ default:
+ break;
+ }
}
- dump(context, opInfo->name);
- // TODO: dump operands
- uint32_t argCount = inst->argCount;
- for (uint32_t ii = 0; ii < argCount; ++ii)
+ dumpID(context, inst);
+ }
+
+ static void dumpType(
+ IRDumpContext* context,
+ IRType* type)
+ {
+ if (!type)
{
- if (ii != 0)
- dump(context, ", ");
- else
+ dumpID(context, type);
+ return;
+ }
+
+ auto op = type->op;
+ auto opInfo = kIROpInfos[op];
+
+ switch (op)
+ {
+ case kIROp_StructType:
+ dumpID(context, type);
+ break;
+
+ default:
{
- dump(context, " ");
+ dump(context, opInfo.name);
+ UInt argCount = type->getArgCount();
+
+ if (argCount > 1)
+ {
+ dump(context, "<");
+ for (UInt aa = 1; aa < argCount; ++aa)
+ {
+ if (aa != 1) dump(context, ",");
+ dumpOperand(context, type->getArg(aa));
+
+ }
+ dump(context, ">");
+ }
}
+ break;
+ }
+ }
- auto argVal = inst->getArgs()[ii].usedValue;
+ static void dumpInstTypeClause(
+ IRDumpContext* context,
+ IRType* type)
+ {
+ dump(context, "\t: ");
+ dumpType(context, type);
- // TODO: actually print the damn operand...
+ }
- dumpID(context, argVal);
- }
+ static void dumpInst(
+ IRDumpContext* context,
+ IRInst* inst);
- dump(context, "\n");
+ static void dumpChildrenRaw(
+ IRDumpContext* context,
+ IRParentInst* parent)
+ {
+ for (auto ii = parent->firstChild; ii; ii = ii->nextInst)
+ {
+ dumpInst(context, ii);
+ }
+ }
+ static void dumpChildren(
+ IRDumpContext* context,
+ IRInst* inst)
+ {
+ auto op = inst->op;
+ auto opInfo = &kIROpInfos[op];
if (opInfo->flags & kIROpFlag_Parent)
{
dumpIndent(context);
dump(context, "{\n");
context->indent++;
auto parent = (IRParentInst*)inst;
- for (auto ii = parent->firstChild; ii; ii = ii->nextInst)
- {
- dumpInst(context, ii);
- }
+ dumpChildrenRaw(context, parent);
context->indent--;
dumpIndent(context);
dump(context, "}\n");
}
}
+ static void dumpInst(
+ IRDumpContext* context,
+ IRInst* inst)
+ {
+ if (!inst)
+ {
+ dumpIndent(context);
+ dump(context, "<null>");
+ return;
+ }
+
+ auto op = inst->op;
+
+ // There are several ops we want to special-case here,
+ // so that they will be more pleasant to look at.
+ //
+ switch (op)
+ {
+ case kIROp_Module:
+ dumpIndent(context);
+ dump(context, "module\n");
+ dumpChildren(context, inst);
+ return;
+
+ case kIROp_Func:
+ {
+ IRFunc* func = (IRFunc*)inst;
+ dump(context, "\n");
+ dumpIndent(context);
+ dump(context, "func ");
+ dumpID(context, func);
+ dump(context, "(\n");
+ context->indent++;
+ for (auto pp = func->getFirstParam(); pp; pp = pp->getNextParam())
+ {
+ if (pp != func->getFirstParam())
+ dump(context, ",\n");
+
+ dumpIndent(context);
+ dump(context, "param ");
+ dumpID(context, pp);
+ dumpInstTypeClause(context, pp->getType());
+ }
+ context->indent--;
+ dump(context, ")\n");
+
+ dumpIndent(context);
+ dump(context, "{\n");
+ context->indent++;
+
+ for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock())
+ {
+ if (bb != func->getFirstBlock())
+ dump(context, "\n");
+ dumpInst(context, bb);
+ }
+
+ context->indent--;
+ dump(context, "}\n");
+ }
+ return;
+
+ case kIROp_TypeType:
+ case kIROp_Param:
+ case kIROp_IntLit:
+ case kIROp_FloatLit:
+ case kIROp_boolConst:
+ // Don't dump here
+ return;
+
+ case kIROp_Block:
+ {
+ IRBlock* block = (IRBlock*)inst;
+
+ context->indent--;
+ dump(context, "block ");
+ dumpID(context, block);
+ dump(context, ":\n");
+ context->indent++;
+
+ dumpChildrenRaw(context, block);
+ }
+ return;
+
+ default:
+ break;
+ }
- void dumpIR(IRModule* module)
+ // We also want to special-case based on the *type*
+ // of the instruction
+ auto type = inst->getType();
+ if (type && type->op == kIROp_TypeType)
+ {
+ // We probably don't want to print most types
+ // when producing "friendly" output.
+ switch (type->op)
+ {
+ case kIROp_StructType:
+ break;
+
+ default:
+ return;
+ }
+ }
+
+
+ // Okay, we have a seemingly "ordinary" op now
+ dumpIndent(context);
+
+ auto opInfo = &kIROpInfos[op];
+
+ if (type && type->op == kIROp_TypeType)
+ {
+ dump(context, "type ");
+ dumpID(context, inst);
+ dump(context, "\t= ");
+ }
+ else if (type && type->op == kIROp_VoidType)
+ {
+ }
+ else
+ {
+ dump(context, "let ");
+ dumpID(context, inst);
+ dumpInstTypeClause(context, type);
+ dump(context, "\t= ");
+ }
+
+
+ dump(context, opInfo->name);
+
+ uint32_t argCount = inst->argCount;
+ dump(context, "(");
+ for (uint32_t ii = 1; ii < argCount; ++ii)
+ {
+ if (ii != 1)
+ dump(context, ", ");
+
+ auto argVal = inst->getArgs()[ii].usedValue;
+
+ dumpOperand(context, argVal);
+ }
+ dump(context, ")");
+
+ dump(context, "\n");
+
+ // The instruction might have children,
+ // so we need to handle those here
+ dumpChildren(context, inst);
+ }
+
+ void printSlangIRAssembly(StringBuilder& builder, IRModule* module)
{
IRDumpContext context;
- context.file = stderr;
+ context.builder = &builder;
context.indent = 0;
- dumpInst(&context, module);
+ dumpChildrenRaw(&context, module);
}
+ String getSlangIRAssembly(IRModule* module)
+ {
+ StringBuilder sb;
+ printSlangIRAssembly(sb, module);
+ return sb;
+ }
+
+
}
diff --git a/source/slang/ir.h b/source/slang/ir.h
index 6e4a25fe1..9b271b09b 100644
--- a/source/slang/ir.h
+++ b/source/slang/ir.h
@@ -9,7 +9,6 @@
#include "../core/basic.h"
-
namespace Slang {
// TODO(tfoley): We should ditch this enumeration
@@ -143,6 +142,9 @@ enum IRDecorationOp : uint16_t
{
kIRDecorationOp_HighLevelDecl,
kIRDecorationOp_Layout,
+ kIRDecorationOp_EntryPoint,
+ kIRDecorationOp_ComputeThreadGroupSize,
+ kIRDecorationOp_LoopControl,
};
// A "decoration" that gets applied to an instruction.
@@ -220,27 +222,6 @@ struct IRInst
// All existing uses of `IRValue` should move to `IRInst`
typedef IRInst IRValue;
-class Decl;
-
-// Associates an IR-level decoration with a source declaration
-// in the high-level AST, that can be used to extract
-// additional information that informs code emission.
-struct IRHighLevelDeclDecoration : IRDecoration
-{
- enum { kDecorationOp = kIRDecorationOp_HighLevelDecl };
-
- Decl* decl;
-};
-
-// Associates an IR-level decoration with a source layout
-struct IRLayoutDecoration : IRDecoration
-{
- enum { kDecorationOp = kIRDecorationOp_Layout };
-
- Layout* layout;
-};
-
-
typedef long long IRIntegerValue;
typedef double IRFloatingPointValue;
@@ -266,126 +247,13 @@ struct IRType : IRInst
{
};
-struct IRVectorType : IRType
-{
- IRUse elementType;
- IRUse elementCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getElementCount() { return elementCount.usedValue; }
-};
-
-struct IRMatrixType : IRType
-{
- IRUse elementType;
- IRUse rowCount;
- IRUse columnCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getRowCount() { return rowCount.usedValue; }
- IRInst* getColumnCount() { return columnCount.usedValue; }
-};
-
-struct IRFuncType : IRType
-{
- IRUse resultType;
- // parameter tyeps are varargs...
-
- IRType* getResultType() { return (IRType*) resultType.usedValue; }
- UInt getParamCount()
- {
- return getArgCount() - 2;
- }
- IRType* getParamType(UInt index)
- {
- return (IRType*) getArg(2 + index);
- }
-};
-
-struct IRPtrType : IRType
-{
- IRUse valueType;
-
- IRType* getValueType() { return (IRType*) valueType.usedValue; }
-};
-
-struct IRTextureType : IRType
-{
- IRUse flavor;
- IRUse elementType;
-
- IRIntegerValue getFlavor() { return ((IRConstant*) flavor.usedValue)->u.intVal; }
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-struct IRBufferType : IRType
-{
- IRUse elementType;
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-
-struct IRUniformBufferType : IRType
-{
- IRUse elementType;
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-struct IRConstantBufferType : IRUniformBufferType {};
-struct IRTextureBufferType : IRUniformBufferType {};
-
-struct IRCall : IRInst
-{
- IRUse func;
-};
-
-struct IRLoad : IRInst
-{
- IRUse ptr;
-};
-
-struct IRStore : IRInst
-{
- IRUse ptr;
- IRUse val;
-};
-
-struct IRStructField;
-struct IRFieldExtract : IRInst
-{
- IRUse base;
- IRUse field;
-
- IRInst* getBase() { return base.usedValue; }
- IRStructField* getField() { return (IRStructField*) field.usedValue; }
-};
-
-struct IRFieldAddress : IRInst
-{
- IRUse base;
- IRUse field;
-
- IRInst* getBase() { return base.usedValue; }
- IRStructField* getField() { return (IRStructField*) field.usedValue; }
-};
-
-
// A instruction that ends a basic block (usually because of control flow)
struct IRTerminatorInst : IRInst
{};
-struct IRReturn : IRTerminatorInst
-{};
-
-struct IRReturnVal : IRReturn
-{
- IRUse val;
-
- IRInst* getVal() { return val.usedValue; }
-};
+bool isTerminatorInst(IROp op);
+bool isTerminatorInst(IRInst* inst);
-struct IRReturnVoid : IRReturn
-{};
// A parent instruction contains a sequence of other instructions
//
@@ -398,20 +266,6 @@ struct IRParentInst : IRInst
IRInst* lastChild;
};
-struct IRStructField : IRInst
-{
- IRType* getFieldType() { return (IRType*) type.usedValue; }
-
- IRStructField* getNextField() { return (IRStructField*) nextInst; }
-};
-
-struct IRStructDecl : IRParentInst
-{
- IRStructField* getFirstField() { return (IRStructField*) firstChild; }
- IRStructField* getLastField() { return (IRStructField*) lastChild; }
-};
-
-
// A basic block is a parent instruction that adds the constraint
// that all the children need to be "ordinary" instructions (so
// no function declarations, or nested blocks). We also expect
@@ -425,6 +279,8 @@ struct IRBlock : IRParentInst
IRBlock* getPrevBlock() { return (IRBlock*) prevInst; }
IRBlock* getNextBlock() { return (IRBlock*) nextInst; }
+
+ IRFunc* getParent() { return (IRFunc*)parent; }
};
// A function parameter is represented by an instruction
@@ -434,8 +290,7 @@ struct IRParam : IRInst
IRParam* getNextParam();
};
-struct IRVar : IRInst
-{};
+struct IRFuncType;
// A function is a parent to zero or more blocks of instructions.
//
@@ -445,9 +300,9 @@ struct IRFunc : IRParentInst
{
IRFuncType* getType() { return (IRFuncType*) type.usedValue; }
- IRType* getResultType() { return getType()->getResultType(); }
- UInt getParamCount() { return getType()->getParamCount(); }
- IRType* getParamType(UInt index) { return getType()->getParamType(index); }
+ IRType* getResultType();
+ UInt getParamCount();
+ IRType* getParamType(UInt index);
IRBlock* getFirstBlock() { return (IRBlock*) firstChild; }
IRBlock* getLastBlock() { return (IRBlock*) lastChild; }
@@ -465,163 +320,10 @@ struct IRModule : IRParentInst
IRInstID idCounter;
};
-// Description of an instruction to be used for global value numbering
-struct IRInstKey
-{
- IRInst* inst;
-
- int GetHashCode();
-};
-
-bool operator==(IRInstKey const& left, IRInstKey const& right);
-
-struct IRConstantKey
-{
- IRConstant* inst;
-
- int GetHashCode();
-};
-bool operator==(IRConstantKey const& left, IRConstantKey const& right);
-
-struct SharedIRBuilder
-{
- // The module that will own all of the IR
- IRModule* module;
-
- Dictionary<IRInstKey, IRInst*> globalValueNumberingMap;
- Dictionary<IRConstantKey, IRConstant*> constantMap;
-};
-
-struct IRBuilder
-{
- // Shared state for all IR builders working on the same module
- SharedIRBuilder* shared;
-
- IRModule* getModule() { return shared->module; }
-
- // The parent instruction to add children to.
- IRParentInst* parentInst;
-
- void addInst(IRParentInst* parent, IRInst* inst);
- void addInst(IRInst* inst);
-
- IRType* getBaseType(BaseType flavor);
- IRType* getBoolType();
- IRType* getVectorType(IRType* elementType, IRValue* elementCount);
- IRType* getMatrixType(
- IRType* elementType,
- IRValue* rowCount,
- IRValue* columnCount);
- IRType* getTypeType();
- IRType* getVoidType();
- IRType* getBlockType();
-
- IRType* getIntrinsicType(
- IROp op,
- UInt argCount,
- IRValue* const* args);
-
- IRStructDecl* createStructType();
- IRStructField* createStructField(IRType* fieldType);
-
- IRType* getFuncType(
- UInt paramCount,
- IRType* const* paramTypes,
- IRType* resultType);
-
- IRType* getPtrType(
- IRType* valueType);
-
- IRValue* getBoolValue(bool value);
- IRValue* getIntValue(IRType* type, IRIntegerValue value);
- IRValue* getFloatValue(IRType* type, IRFloatingPointValue value);
-
- IRInst* emitCallInst(
- IRType* type,
- IRValue* func,
- UInt argCount,
- IRValue* const* args);
-
- IRInst* emitIntrinsicInst(
- IRType* type,
- IROp op,
- UInt argCount,
- IRValue* const* args);
-
- IRInst* emitConstructorInst(
- IRType* type,
- UInt argCount,
- IRValue* const* args);
-
- IRModule* createModule();
-
- IRFunc* createFunc();
-
- IRBlock* createBlock();
- IRBlock* emitBlock();
-
- IRParam* emitParam(
- IRType* type);
-
- IRVar* emitVar(
- IRType* type);
-
- IRInst* emitLoad(
- IRValue* ptr);
-
- IRInst* emitStore(
- IRValue* dstPtr,
- IRValue* srcVal);
-
- IRInst* emitFieldExtract(
- IRType* type,
- IRValue* base,
- IRStructField* field);
-
- IRInst* emitFieldAddress(
- IRType* type,
- IRValue* basePtr,
- IRStructField* field);
-
- IRInst* emitElementExtract(
- IRType* type,
- IRValue* base,
- IRValue* index);
-
- IRInst* emitElementAddress(
- IRType* type,
- IRValue* basePtr,
- IRValue* index);
-
-
- IRInst* emitReturn(
- IRValue* val);
-
- IRInst* emitReturn();
-
- IRDecoration* addDecorationImpl(
- IRInst* inst,
- UInt decorationSize,
- IRDecorationOp op);
-
- template<typename T>
- T* addDecoration(IRInst* inst, IRDecorationOp op)
- {
- return (T*) addDecorationImpl(inst, sizeof(T), op);
- }
-
- template<typename T>
- T* addDecoration(IRInst* inst)
- {
- return (T*) addDecorationImpl(inst, sizeof(T), IRDecorationOp(T::kDecorationOp));
- }
-
- IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRInst* inst, Decl* decl);
- IRLayoutDecoration* addLayoutDecoration(IRInst* inst, Layout* layout);
-};
-
-void dumpIR(IRModule* module);
+void printSlangIRAssembly(StringBuilder& builder, IRModule* module);
+String getSlangIRAssembly(IRModule* module);
}
+
#endif
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 10b4aefca..5ee6d5460 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -2,44 +2,130 @@
#include "lower-to-ir.h"
#include "ir.h"
+#include "ir-insts.h"
#include "type-layout.h"
#include "visitor.h"
namespace Slang
{
-struct BoundMemberInfo;
+// This file implements lowering of the Slang AST to a simpler SSA
+// intermediate representation.
+//
+// IR is generated in a context (`IRGenContext`), which tracks the current
+// location in the IR where code should be emitted (e.g., what basic
+// block to add instructions to). Lowering a statement will emit some
+// number of instructions to the context, and possibly change the
+// insertion point (because of control flow).
+//
+// When lowering an expression we have a more interesting challenge, for
+// two main reasons:
+//
+// 1. There might be types that are representible in the AST, but which
+// we don't want to support natively in the IR. An example is a `struct`
+// type with both ordinary and resource-type members; we might want to
+// split values with such a type into distinct values during lowering.
+//
+// 2. We need to handle the difference between l-value and r-value expressions,
+// and in particular the fact that HLSL/Slang supports complicated sorts
+// of l-values (e.g., `someVector.zxy` is an l-value, even though it can't
+// be represented by a single pointer), and also allows l-values to appear
+// in multiple contexts (not just the left-hand side of assignment, but
+// also as an argument to match an `out` or `in out` parameter).
+//
+// Our solution to both of these problems is the same. Rather than having
+// the lowering of an expression return a single IR-level value (`IRInst*`),
+// we have it return a more complex type (`LoweredValInfo`) which can represent
+// a wider range of conceptual "values" which might correspond to multiple IR-level
+// values, and/or represent a pointer to an l-value rather than the r-value itself.
+
+// We want to keep the representation of a `LoweringValInfo` relatively light
+// - right now it is just a single pointer plus a "tag" to distinguish the cases.
+//
+// This means that cases that can't fit in a single pointer need a heap allocation
+// to store their payload. For simplicity we represent all of these with a class
+// hierarchy:
+//
+struct ExtendedValueInfo : RefObject
+{};
-struct SubscriptInfo : RefObject
+// This case is used to indicate a value that is a reference
+// to an AST-level subscript declaration.
+//
+struct SubscriptInfo : ExtendedValueInfo
{
DeclRef<SubscriptDecl> declRef;
};
-struct BoundSubscriptInfo : RefObject
+// This case is used to indicate a reference to an AST-level
+// subscript operation bound to particular arguments.
+//
+// For example in a case like this:
+//
+// RWStructuredBuffer<Foo> gBuffer;
+// ... gBuffer[someIndex] ...
+//
+// the expression `gBuffer[someIndex]` will be lowered to
+// a value that references `RWStructureBuffer<Foo>::operator[]`
+// with arguments `(gBuffer, someIndex)`.
+//
+// Such a value can be an l-value, and depending on the context
+// where it is used, can lower into a call to either the getter
+// or setter operations of the subscript.
+//
+struct BoundSubscriptInfo : ExtendedValueInfo
{
DeclRef<SubscriptDecl> declRef;
IRType* type;
List<IRValue*> args;
};
+// Some cases of `ExtendedValueInfo` need to
+// recursively contain `LoweredValInfo`s, and
+// so we forward declare them here and fill
+// them in later.
+//
+struct BoundMemberInfo;
+struct SwizzledLValueInfo;
+
+
+// This type is our core representation of lowered values.
+// In the simple case, it just wraps an `IRInst*`.
+// More complex cases, representing l-values or aggregate
+// values are also supported.
struct LoweredValInfo
{
+ // Which of the cases of value are we looking at?
enum class Flavor
{
+ // No value (akin to a null pointer)
None,
+
+ // A simple IR value
Simple,
+
+ // An l-value reprsented as an IR
+ // pointer to the value
Ptr,
+
+ // A member declaration bound to a particular `this` value
BoundMember,
+
+ // A reference to an AST-level subscript operation
Subscript,
+
+ // An AST-level subscript operation bound to a particular
+ // object and arguments.
BoundSubscript,
+
+ // The result of applying swizzling to an l-value
+ SwizzledLValue,
};
union
{
IRValue* val;
- BoundMemberInfo* boundMemberInfo;
- SubscriptInfo* subscriptInfo;
- BoundSubscriptInfo* boundSubscriptInfo;
+ ExtendedValueInfo* ext;
};
Flavor flavor;
@@ -66,33 +152,87 @@ struct LoweredValInfo
}
static LoweredValInfo boundMember(
- LoweredValInfo const& base,
- LoweredValInfo const& member);
+ BoundMemberInfo* boundMemberInfo);
+
+ BoundMemberInfo* getBoundMemberInfo()
+ {
+ assert(flavor == Flavor::BoundMember);
+ return (BoundMemberInfo*)ext;
+ }
static LoweredValInfo subscript(
SubscriptInfo* subscriptInfo);
+ SubscriptInfo* getSubscriptInfo()
+ {
+ assert(flavor == Flavor::Subscript);
+ return (SubscriptInfo*)ext;
+ }
+
static LoweredValInfo boundSubscript(
BoundSubscriptInfo* boundSubscriptInfo);
+
+ BoundSubscriptInfo* getBoundSubscriptInfo()
+ {
+ assert(flavor == Flavor::BoundSubscript);
+ return (BoundSubscriptInfo*)ext;
+ }
+
+ static LoweredValInfo swizzledLValue(
+ SwizzledLValueInfo* extInfo);
+
+ SwizzledLValueInfo* getSwizzledLValueInfo()
+ {
+ assert(flavor == Flavor::SwizzledLValue);
+ return (SwizzledLValueInfo*)ext;
+ }
};
-struct BoundMemberInfo
+// Represents some declaration bound to a particular
+// object. For example, if we had `obj.f` where `f`
+// is a member function, we'd use a `BoundMemberInfo`
+// to represnet this.
+//
+// Note: This case is largely avoided by special-casing
+// in the handling of calls (like `obj.f(arg)`), but
+// it is being left here as an example of what we might
+// need/want to do in the long term.
+struct BoundMemberInfo : ExtendedValueInfo
{
+ // The base object
LoweredValInfo base;
- LoweredValInfo member;
+
+ // The (AST-level) declaration reference.
+ DeclRef<Decl> declRef;
};
-LoweredValInfo LoweredValInfo::boundMember(
- LoweredValInfo const& base,
- LoweredValInfo const& member)
+// Represents the result of a swizzle operation in
+// an l-value context. A swizzle without duplicate
+// elements is allowed as an l-value, even if the
+// element are non-contiguous (`.xz`) or out of
+// order (`.zxy`).
+//
+struct SwizzledLValueInfo : ExtendedValueInfo
{
- BoundMemberInfo* boundMember = new BoundMemberInfo();
- boundMember->base = base;
- boundMember->member = member;
+ // IR-level The type of the expression.
+ IRType* type;
+
+ // The base expression (this should be an l-value)
+ LoweredValInfo base;
+
+ // The number of elements in the swizzle
+ UInt elementCount;
+ // THe indices for the elements being swizzled
+ UInt elementIndices[4];
+};
+
+LoweredValInfo LoweredValInfo::boundMember(
+ BoundMemberInfo* boundMemberInfo)
+{
LoweredValInfo info;
info.flavor = Flavor::BoundMember;
- info.boundMemberInfo = boundMember;
+ info.ext = boundMemberInfo;
return info;
}
@@ -101,7 +241,7 @@ LoweredValInfo LoweredValInfo::subscript(
{
LoweredValInfo info;
info.flavor = Flavor::Subscript;
- info.subscriptInfo = subscriptInfo;
+ info.ext = subscriptInfo;
return info;
}
@@ -110,10 +250,18 @@ LoweredValInfo LoweredValInfo::boundSubscript(
{
LoweredValInfo info;
info.flavor = Flavor::BoundSubscript;
- info.boundSubscriptInfo = boundSubscriptInfo;
+ info.ext = boundSubscriptInfo;
return info;
}
+LoweredValInfo LoweredValInfo::swizzledLValue(
+ SwizzledLValueInfo* extInfo)
+{
+ LoweredValInfo info;
+ info.flavor = Flavor::SwizzledLValue;
+ info.ext = extInfo;
+ return info;
+}
struct SharedIRGenContext
{
@@ -123,8 +271,13 @@ struct SharedIRGenContext
Dictionary<DeclRef<Decl>, LoweredValInfo> declValues;
- // Arrays we keep around strictly for memory-management purposes
- List<RefPtr<BoundSubscriptInfo>> boundSubscripts;
+ // Arrays we keep around strictly for memory-management purposes:
+
+ // Any extended values created during lowering need
+ // to be cleaned up after the fact. We don't try
+ // to reference-count these along the way because
+ // they need to get stored into a `union` inside `LoweredValInfo`
+ List<RefPtr<ExtendedValueInfo>> extValues;
};
@@ -178,6 +331,29 @@ LoweredValInfo emitCallToVal(
}
}
+LoweredValInfo emitCompoundAssignOp(
+ IRGenContext* context,
+ IRType* type,
+ IROp op,
+ UInt argCount,
+ IRValue* const* args)
+{
+ auto builder = context->irBuilder;
+
+ assert(argCount == 2);
+ auto leftPtr = args[0];
+ auto rightVal = args[1];
+
+ auto leftVal = builder->emitLoad(leftPtr);
+
+ IRInst* innerArgs[] = { leftVal, rightVal };
+ auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
+
+ builder->emitStore(leftPtr, innerOp);
+
+ return LoweredValInfo::ptr(leftPtr);
+}
+
// Given a `DeclRef` for something callable, along with a bunch of
// arguments, emit an appropriate call to it.
LoweredValInfo emitCallToDeclRef(
@@ -229,7 +405,7 @@ LoweredValInfo emitCallToDeclRef(
boundSubscript->type = type;
boundSubscript->args.AddRange(args, argCount);
- context->shared->boundSubscripts.Add(boundSubscript);
+ context->shared->extValues.Add(boundSubscript);
return LoweredValInfo::boundSubscript(boundSubscript);
}
@@ -245,6 +421,35 @@ LoweredValInfo emitCallToDeclRef(
{
auto op = getIntrinsicOp(funcDecl, intrinsicOpModifier);
+ if (Int(op) < 0)
+ {
+ switch (op)
+ {
+ case kIRPseudoOp_Pos:
+ return LoweredValInfo::simple(args[0]);
+
+#define CASE(COMPOUND, OP) \
+ case COMPOUND: return emitCompoundAssignOp(context, type, OP, argCount, args)
+
+ CASE(kIRPseudoOp_AddAssign, kIROp_Add);
+ CASE(kIRPseudoOp_SubAssign, kIROp_Sub);
+ CASE(kIRPseudoOp_MulAssign, kIROp_Mul);
+ CASE(kIRPseudoOp_DivAssign, kIROp_Div);
+ CASE(kIRPseudoOp_ModAssign, kIROp_Mod);
+ CASE(kIRPseudoOp_AndAssign, kIROp_BitAnd);
+ CASE(kIRPseudoOp_OrAssign, kIROp_BitOr);
+ CASE(kIRPseudoOp_XorAssign, kIROp_BitXor);
+ CASE(kIRPseudoOp_LshAssign, kIROp_Lsh);
+ CASE(kIRPseudoOp_RshAssign, kIROp_Rsh);
+
+#undef CASE
+
+ default:
+ SLANG_UNIMPLEMENTED_X("IR pseudo-op");
+ break;
+ }
+ }
+
return LoweredValInfo::simple(builder->emitIntrinsicInst(
type,
op,
@@ -277,6 +482,8 @@ LoweredValInfo emitCallToDeclRef(
IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered)
{
+ auto builder = context->irBuilder;
+
top:
switch(lowered.flavor)
{
@@ -287,12 +494,11 @@ top:
return lowered.val;
case LoweredValInfo::Flavor::Ptr:
- return context->irBuilder->emitLoad(lowered.val);
+ return builder->emitLoad(lowered.val);
case LoweredValInfo::Flavor::BoundSubscript:
{
- auto boundSubscriptInfo = lowered.boundSubscriptInfo;
- auto builder = context->irBuilder;
+ auto boundSubscriptInfo = lowered.getBoundSubscriptInfo();
for (auto getter : getMembersOfType<GetterDecl>(boundSubscriptInfo->declRef))
{
@@ -309,6 +515,17 @@ top:
}
break;
+ case LoweredValInfo::Flavor::SwizzledLValue:
+ {
+ auto swizzleInfo = lowered.getSwizzledLValueInfo();
+
+ return builder->emitSwizzle(
+ swizzleInfo->type,
+ getSimpleVal(context, swizzleInfo->base),
+ swizzleInfo->elementCount,
+ swizzleInfo->elementIndices);
+ }
+
default:
SLANG_UNEXPECTED("unhandled value flavor");
return nullptr;
@@ -397,8 +614,11 @@ IRType* lowerSimpleType(
return getSimpleType(lowered);
}
+LoweredValInfo lowerLValueExpr(
+ IRGenContext* context,
+ Expr* expr);
-LoweredValInfo lowerExpr(
+LoweredValInfo lowerRValueExpr(
IRGenContext* context,
Expr* expr);
@@ -528,6 +748,38 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount);
}
+
+ LoweredTypeInfo getArrayType(
+ LoweredTypeInfo const& loweredElementType,
+ IRValue* irElementCount)
+ {
+ switch (loweredElementType.flavor)
+ {
+ case LoweredTypeInfo::Flavor::Simple:
+ return getBuilder()->getArrayType(
+ loweredElementType.type,
+ irElementCount);
+ break;
+
+ default:
+ SLANG_UNEXPECTED("array element type");
+ break;
+ }
+ }
+
+ LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type)
+ {
+ auto loweredElementType = lowerType(context, type->BaseType);
+ if (auto elementCount = type->ArrayLength)
+ {
+ auto irElementCount = lowerSimpleVal(context, elementCount);
+ return getArrayType(loweredElementType, irElementCount);
+ }
+ else
+ {
+ return getArrayType(loweredElementType, nullptr);
+ }
+ }
};
LoweredValInfo lowerVal(
@@ -556,15 +808,79 @@ struct LoweringVisitor
, ValVisitor<LoweringVisitor, RefPtr<Val>, RefPtr<Type>>
#endif
+LoweredValInfo createVar(
+ IRGenContext* context,
+ LoweredTypeInfo type,
+ Decl* decl = nullptr,
+ Layout* layout = nullptr,
+ IRAddressSpace addressSpace = kIRAddressSpace_Default)
+{
+ auto builder = context->irBuilder;
+ switch( type.flavor )
+ {
+ case LoweredTypeInfo::Flavor::Simple:
+ {
+ auto irAlloc = builder->emitVar(getSimpleType(type), addressSpace);
+
+ if (decl)
+ {
+ builder->addHighLevelDeclDecoration(irAlloc, decl);
+ }
+
+ if (layout)
+ {
+ builder->addLayoutDecoration(irAlloc, layout);
+ }
+
+
+ return LoweredValInfo::ptr(irAlloc);
+ }
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("var type");
+ return LoweredValInfo();
+ }
+
+}
+
+void addArgs(
+ IRGenContext* context,
+ List<IRValue*>* ioArgs,
+ LoweredValInfo argInfo)
+{
+ auto& args = *ioArgs;
+ switch( argInfo.flavor )
+ {
+ case LoweredValInfo::Flavor::Simple:
+ case LoweredValInfo::Flavor::Ptr:
+ case LoweredValInfo::Flavor::SwizzledLValue:
+ args.Add(getSimpleVal(context, argInfo));
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("addArgs case");
+ break;
+ }
+}
//
-struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
+template<typename Derived>
+struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
{
IRGenContext* context;
IRBuilder* getBuilder() { return context->irBuilder; }
+ // Lower an expression that should have the same l-value-ness
+ // as the visitor itself.
+ LoweredValInfo lowerSubExpr(Expr* expr)
+ {
+ return dispatch(expr);
+ }
+
+
LoweredValInfo visitVarExpr(VarExpr* expr)
{
LoweredValInfo info = ensureDecl(context, expr->declRef);
@@ -576,6 +892,83 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
SLANG_UNEXPECTED("overloaded expressions should not occur in checked AST");
}
+ LoweredValInfo visitIndexExpr(IndexExpr* expr)
+ {
+ auto type = lowerType(context, expr->type);
+ auto baseVal = lowerSubExpr(expr->BaseExpression);
+ auto indexVal = getSimpleVal(context, lowerRValueExpr(context, expr->IndexExpression));
+
+ return subscriptValue(type, baseVal, indexVal);
+ }
+
+ LoweredValInfo visitMemberExpr(MemberExpr* expr)
+ {
+ auto loweredType = lowerType(context, expr->type);
+ auto loweredBase = lowerRValueExpr(context, expr->BaseExpression);
+
+ auto declRef = expr->declRef;
+ if (auto fieldDeclRef = declRef.As<StructField>())
+ {
+ // Okay, easy enough: we have a reference to a field of a struct type...
+
+ auto loweredField = ensureDecl(context, fieldDeclRef);
+ return extractField(loweredType, loweredBase, loweredField);
+ }
+ else if (auto callableDeclRef = declRef.As<CallableDecl>())
+ {
+ RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo();
+ boundMemberInfo->base = loweredBase;
+ boundMemberInfo->declRef = callableDeclRef;
+ return LoweredValInfo::boundMember(boundMemberInfo);
+ }
+
+ SLANG_UNIMPLEMENTED_X("codegen for subscript expression");
+ }
+
+ // We will always lower a dereference expression (`*ptr`)
+ // as an l-value, since that is the easiest way to handle it.
+ LoweredValInfo visitDerefExpr(DerefExpr* expr)
+ {
+ auto loweredType = lowerType(context, expr->type);
+ auto loweredBase = lowerRValueExpr(context, expr->base);
+
+ // TODO: handle tupel-type for `base`
+
+ // The type of the lowered base must by some kind of pointer,
+ // in order for a dereference to make senese, so we just
+ // need to extract the value type from that pointer here.
+ //
+ auto loweredBaseVal = getSimpleVal(context, loweredBase);
+ auto loweredBaseType = loweredBaseVal->getType();
+ switch( loweredBaseType->op )
+ {
+ case kIROp_PtrType:
+ // TODO: should we enumerate these explicitly?
+ case kIROp_ConstantBufferType:
+ case kIROp_TextureBufferType:
+ // Note that we do *not* perform an actual `load` operation
+ // here, but rather just use the pointer value to construct
+ // an appropriate `LoweredValInfo` representing the underlying
+ // dereference.
+ //
+ // This is important so that an expression like `&((*foo).bar)`
+ // (which is desugared from `&foo->bar`) can be handled; such
+ // an expression does *not* perform a dereference at runtime,
+ // and is just a bit of pointer math.
+ //
+ return LoweredValInfo::ptr(loweredBaseVal);
+
+ default:
+ SLANG_UNIMPLEMENTED_X("codegen for deref expression");
+ return LoweredValInfo();
+ }
+ }
+
+ LoweredValInfo visitParenExpr(ParenExpr* expr)
+ {
+ return lowerSubExpr(expr->base);
+ }
+
LoweredValInfo visitInitializerListExpr(InitializerListExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for initializer list expression");
@@ -605,23 +998,6 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
SLANG_UNIMPLEMENTED_X("codegen for aggregate type constructor expression");
}
- void addArgs(List<IRValue*>* ioArgs, LoweredValInfo argInfo)
- {
- auto& args = *ioArgs;
- switch( argInfo.flavor )
- {
- case LoweredValInfo::Flavor::Simple:
- case LoweredValInfo::Flavor::Ptr:
- args.Add(getSimpleVal(context, argInfo));
- break;
-
- default:
- SLANG_UNIMPLEMENTED_X("addArgs case");
- break;
- }
- }
-
-
// Add arguments that appeared directly in an argument list
// to the list of argument values for a call.
void addDirectCallArgs(
@@ -632,45 +1008,129 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
for( auto arg : expr->Arguments )
{
- auto loweredArg = lowerExpr(context, arg);
- addArgs(&irArgs, loweredArg);
+ // TODO: Need to handle case of l-value arguments,
+ // when they are matched to `out` or `in out` parameters.
+ auto loweredArg = lowerRValueExpr(context, arg);
+ addArgs(context, ioArgs, loweredArg);
}
}
- // Try to add "all" the arguments for a call to the argument list,
- // including implicit arguments that come from (e.g.,) a member
- // expression used to form the call.
- void addCallArgs(
- InvokeExpr* expr,
- List<IRValue*>* ioArgs)
- {
- auto& irArgs = *ioArgs;
+ // After a call to a function with `out` or `in out`
+ // parameters, we may need to copy data back into
+ // the l-value locations used for output arguments.
+ //
+ // During lowering of the argument list, we build
+ // up a list of these "fixup" assignments that need
+ // to be performed.
+ struct OutArgumentFixup
+ {
+ LoweredValInfo dst;
+ LoweredValInfo src;
+ };
- // TODO: should unwrap any layers of identity expressions around this...
- if( auto baseMemberExpr = expr->FunctionExpr.As<MemberExpr>() )
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ DeclRef<CallableDecl> funcDeclRef,
+ List<IRValue*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ auto funcDecl = funcDeclRef.getDecl();
+ auto& args = expr->Arguments;
+ UInt argCount = expr->Arguments.Count();
+ UInt argIndex = 0;
+ for (auto paramDeclRef : getMembersOfType<ParamDecl>(funcDeclRef))
{
- // This call took the form of a member function call, so
- // we need to correctly add the `this` argument as
- // an explicit argument.
- //
- auto loweredBase = lowerExpr(context, baseMemberExpr->BaseExpression);
- addArgs(&irArgs, loweredBase);
- }
+ if (argIndex >= argCount)
+ {
+ // The remaining parameters must be defaulted...
+ break;
+ }
- addDirectCallArgs(expr, ioArgs);
- }
+ auto paramDecl = paramDeclRef.getDecl();
+ auto paramType = lowerType(context, GetType(paramDeclRef));
+ auto argExpr = expr->Arguments[argIndex++];
- LoweredValInfo lowerIntrinsicCall(
- InvokeExpr* expr,
- IROp intrinsicOp)
- {
- auto type = lowerSimpleType(context, expr->type);
+ if (paramDecl->HasModifier<OutModifier>()
+ || paramDecl->HasModifier<InOutModifier>())
+ {
+ // This is a `out` or `inout` parameter, and so
+ // the argument must be lowered as an l-value.
+
+ LoweredValInfo loweredArg = lowerLValueExpr(context, argExpr);
+
+ // According to our "calling convention" we need to
+ // pass a pointer into the callee.
+ //
+ // A naive approach would be to just take the address
+ // of `loweredArg` above and pass it in, but that
+ // has two issues:
+ //
+ // 1. The l-value might not be something that has a single
+ // well-defined "address" (e.g., `foo.xzy`).
+ //
+ // 2. The l-value argument might actually alias some other
+ // storage that the callee will access (e.g., we are
+ // passing in a global variable, or two `out` parameters
+ // are being passed the same location in an array).
+ //
+ // In each of these cases, the safe option is to create
+ // a temporary variable to use for argument-passing,
+ // and then do copy-in/copy-out around the call.
+
+ LoweredValInfo tempVar = createVar(context, paramType);
+
+ // If the parameter is `in out` or `inout`, then we need
+ // to ensure that we pass in the original value stored
+ // in the argument, which we accomplish by assigning
+ // from the l-value to our temp.
+ if (paramDecl->HasModifier<InModifier>()
+ || paramDecl->HasModifier<InOutModifier>())
+ {
+ assign(context, tempVar, loweredArg);
+ }
- List<IRValue*> irArgs;
- addCallArgs(expr, &irArgs);
- UInt argCount = irArgs.Count();
+ // Now we can pass the address of the temporary variable
+ // to the callee as the actual argument for the `in out`
+ assert(tempVar.flavor == LoweredValInfo::Flavor::Ptr);
+ (*ioArgs).Add(tempVar.val);
- return LoweredValInfo::simple(getBuilder()->emitIntrinsicInst(type, intrinsicOp, argCount, &irArgs[0]));
+ // Finally, after the call we will need
+ // to copy in the other direction: from our
+ // temp back to the original l-value.
+ OutArgumentFixup fixup;
+ fixup.src = tempVar;
+ fixup.dst = loweredArg;
+
+ (*ioFixups).Add(fixup);
+
+ }
+ else
+ {
+ // This is a pure input parameter, and so we will
+ // pass it as an r-value.
+ LoweredValInfo loweredArg = lowerRValueExpr(context, argExpr);
+ addArgs(context, ioArgs, loweredArg);
+ }
+ }
+ }
+
+ // Add arguments that appeared directly in an argument list
+ // to the list of argument values for a call.
+ void addDirectCallArgs(
+ InvokeExpr* expr,
+ DeclRef<Decl> funcDeclRef,
+ List<IRValue*>* ioArgs,
+ List<OutArgumentFixup>* ioFixups)
+ {
+ if (auto callableDeclRef = funcDeclRef.As<CallableDecl>())
+ {
+ addDirectCallArgs(expr, callableDeclRef, ioArgs, ioFixups);
+ }
+ else
+ {
+ SLANG_UNEXPECTED("shouldn't relaly happen");
+ addDirectCallArgs(expr, ioArgs);
+ }
}
void addFuncBaseArgs(
@@ -684,9 +1144,12 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
}
}
- LoweredValInfo lowerSimpleCall(InvokeExpr* expr)
+ void applyOutArgumentFixups(List<OutArgumentFixup> const& fixups)
{
-
+ for (auto fixup : fixups)
+ {
+ assign(context, fixup.dst, fixup.src);
+ }
}
LoweredValInfo visitInvokeExpr(InvokeExpr* expr)
@@ -711,40 +1174,60 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
// arguments that will be part of the call.
List<IRValue*> irArgs;
+ // We will also collect "fixup" actions that need
+ // to be performed after teh call, in order to
+ // copy the final values for `out` parameters
+ // back to their arguments.
+ List<OutArgumentFixup> argFixups;
auto funcExpr = expr->FunctionExpr;
if (auto memberFuncExpr = funcExpr.As<MemberExpr>())
{
- auto loweredBaseVal = lowerExpr(context, memberFuncExpr->BaseExpression);
- addArgs(&irArgs, loweredBaseVal);
+ auto loweredBaseVal = lowerRValueExpr(context, memberFuncExpr->BaseExpression);
+ addArgs(context, &irArgs, loweredBaseVal);
auto funcDeclRef = memberFuncExpr->declRef;
- addDirectCallArgs(expr, &irArgs);
- return emitCallToDeclRef(context, type, funcDeclRef, irArgs);
+ addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups);
+ auto result = emitCallToDeclRef(context, type, funcDeclRef, irArgs);
+ applyOutArgumentFixups(argFixups);
+ return result;
}
else if (auto staticMemberFuncExpr = funcExpr.As<StaticMemberExpr>())
{
auto funcDeclRef = staticMemberFuncExpr->declRef;
- addDirectCallArgs(expr, &irArgs);
- return emitCallToDeclRef(context, type, funcDeclRef, irArgs);
+ addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups);
+ auto result = emitCallToDeclRef(context, type, funcDeclRef, irArgs);
+ applyOutArgumentFixups(argFixups);
+ return result;
}
else if (auto varExpr = funcExpr.As<VarExpr>())
{
auto funcDeclRef = varExpr->declRef;
- addDirectCallArgs(expr, &irArgs);
- return emitCallToDeclRef(context, type, funcDeclRef, irArgs);
+ addDirectCallArgs(expr, funcDeclRef, &irArgs, &argFixups);
+ auto result = emitCallToDeclRef(context, type, funcDeclRef, irArgs);
+ applyOutArgumentFixups(argFixups);
+ return result;
}
// The default case is to assume that we just have
// an ordinary expression, and can lower it as such.
- LoweredValInfo funcVal = lowerExpr(context, expr->FunctionExpr);
+ LoweredValInfo funcVal = lowerRValueExpr(context, expr->FunctionExpr);
// Now we add any direct arguments from the call expression itself.
addDirectCallArgs(expr, &irArgs);
// Delegate to the logic for invoking a value.
- return emitCallToVal(context, type, funcVal, irArgs.Count(), irArgs.Buffer());
+ auto result = emitCallToVal(context, type, funcVal, irArgs.Count(), irArgs.Buffer());
+
+ // TODO: because of the nature of how the `emitCallToVal` case works
+ // right now, we don't have information on in/out parameters, and
+ // so we can't collect info to apply fixups.
+ //
+ // Once we have a better representation for function types, though,
+ // this should be fixable.
+
+ return result;
}
LoweredValInfo subscriptValue(
@@ -776,15 +1259,6 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
}
- LoweredValInfo visitIndexExpr(IndexExpr* expr)
- {
- auto type = lowerType(context, expr->type);
- auto baseVal = lowerExpr(context, expr->BaseExpression);
- auto indexVal = getSimpleVal(context, lowerExpr(context, expr->IndexExpression));
-
- return subscriptValue(type, baseVal, indexVal);
- }
-
LoweredValInfo extractField(
LoweredTypeInfo fieldType,
LoweredValInfo base,
@@ -824,70 +1298,6 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
return ensureDecl(context, expr->declRef);
}
- LoweredValInfo visitMemberExpr(MemberExpr* expr)
- {
- auto loweredType = lowerType(context, expr->type);
- auto loweredBase = lowerExpr(context, expr->BaseExpression);
-
- auto declRef = expr->declRef;
- if (auto fieldDeclRef = declRef.As<StructField>())
- {
- // Okay, easy enough: we have a reference to a field of a struct type...
-
- auto loweredField = ensureDecl(context, fieldDeclRef);
- return extractField(loweredType, loweredBase, loweredField);
- }
- else if (auto callableDeclRef = declRef.As<CallableDecl>())
- {
- auto loweredFunc = ensureDecl(context, callableDeclRef);
- return LoweredValInfo::boundMember(loweredBase, loweredFunc);
- }
-
- SLANG_UNIMPLEMENTED_X("codegen for subscript expression");
- }
-
- LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
- {
- SLANG_UNIMPLEMENTED_X("codegen for swizzle expression");
- }
-
- LoweredValInfo visitDerefExpr(DerefExpr* expr)
- {
- auto loweredType = lowerType(context, expr->type);
- auto loweredBase = lowerExpr(context, expr->base);
-
- // TODO: handle tupel-type for `base`
-
- // The type of the lowered base must by some kind of pointer,
- // in order for a dereference to make senese, so we just
- // need to extract the value type from that pointer here.
- //
- auto loweredBaseVal = getSimpleVal(context, loweredBase);
- auto loweredBaseType = loweredBaseVal->getType();
- switch( loweredBaseType->op )
- {
- case kIROp_PtrType:
- // TODO: should we enumerate these explicitly?
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
- // Note that we do *not* perform an actual `load` operation
- // here, but rather just use the pointer value to construct
- // an appropriate `LoweredValInfo` representing the underlying
- // dereference.
- //
- // This is important so that an expression like `&((*foo).bar)`
- // (which is desugared from `&foo->bar`) can be handled; such
- // an expression does *not* perform a dereference at runtime,
- // and is just a bit of pointer math.
- //
- return LoweredValInfo::ptr(loweredBaseVal);
-
- default:
- SLANG_UNIMPLEMENTED_X("codegen for deref expression");
- return LoweredValInfo();
- }
- }
-
LoweredValInfo visitTypeCastExpr(TypeCastExpr* expr)
{
SLANG_UNIMPLEMENTED_X("codegen for type cast expression");
@@ -916,8 +1326,8 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
// and right-hand sides, and then peform an assignment
// based on the resulting values.
//
- auto leftVal = lowerExpr(context, expr->left);
- auto rightVal = lowerExpr(context, expr->right);
+ auto leftVal = lowerLValueExpr(context, expr->left);
+ auto rightVal = lowerRValueExpr(context, expr->right);
assign(context, leftVal, rightVal);
// The result value of the assignment expression is
@@ -925,18 +1335,80 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
// to be an l-value).
return leftVal;
}
+};
- LoweredValInfo visitParenExpr(ParenExpr* expr)
+struct LValueExprLoweringVisitor : ExprLoweringVisitorBase<LValueExprLoweringVisitor>
+{
+ // When visiting a swizzle expression in an l-value context,
+ // we need to construct a "sizzled l-value."
+ LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
{
- return lowerExpr(context, expr->base);
+ auto irType = lowerSimpleType(context, expr->type);
+ auto loweredBase = lowerRValueExpr(context, expr->base);
+
+ RefPtr<SwizzledLValueInfo> swizzledLValue = new SwizzledLValueInfo();
+ swizzledLValue->type = irType;
+ swizzledLValue->base = loweredBase;
+
+ UInt elementCount = (UInt)expr->elementCount;
+ swizzledLValue->elementCount = elementCount;
+ for (UInt ii = 0; ii < elementCount; ++ii)
+ {
+ swizzledLValue->elementIndices[ii] = (UInt) expr->elementIndices[ii];
+ }
+
+ context->shared->extValues.Add(swizzledLValue);
+ return LoweredValInfo::swizzledLValue(swizzledLValue);
+ }
+
+};
+
+struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVisitor>
+{
+ // A swizzle in an r-value context can save time by just
+ // emitting the swizzle instuctions directly.
+ LoweredValInfo visitSwizzleExpr(SwizzleExpr* expr)
+ {
+ auto irType = lowerSimpleType(context, expr->type);
+ auto irBase = getSimpleVal(context, lowerRValueExpr(context, expr->base));
+
+ auto builder = getBuilder();
+
+ auto irIntType = builder->getBaseType(BaseType::Int);
+
+ UInt elementCount = (UInt)expr->elementCount;
+ IRValue* irElementIndices[4];
+ for (UInt ii = 0; ii < elementCount; ++ii)
+ {
+ irElementIndices[ii] = builder->getIntValue(
+ irIntType,
+ (IRIntegerValue)expr->elementIndices[ii]);
+ }
+
+ auto irSwizzle = builder->emitSwizzle(
+ irType,
+ irBase,
+ elementCount,
+ &irElementIndices[0]);
+
+ return LoweredValInfo::simple(irSwizzle);
}
};
-LoweredValInfo lowerExpr(
+LoweredValInfo lowerLValueExpr(
+ IRGenContext* context,
+ Expr* expr)
+{
+ LValueExprLoweringVisitor visitor;
+ visitor.context = context;
+ return visitor.dispatch(expr);
+}
+
+LoweredValInfo lowerRValueExpr(
IRGenContext* context,
Expr* expr)
{
- ExprLoweringVisitor visitor;
+ RValueExprLoweringVisitor visitor;
visitor.context = context;
return visitor.dispatch(expr);
}
@@ -952,6 +1424,185 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
SLANG_UNIMPLEMENTED_X("stmt catch-all");
}
+ // Create a basic block in the current function,
+ // so that it can be used for a label.
+ IRBlock* createBlock()
+ {
+ return getBuilder()->createBlock();
+ }
+
+ // Insert a block at the current location (ending
+ // the previous block with an unconditional jump
+ // if needed).
+ void insertBlock(IRBlock* block)
+ {
+ auto builder = getBuilder();
+ auto parent = builder->parentInst;
+
+ IRBlock* prevBlock = nullptr;
+ IRFunc* parentFunc = nullptr;
+
+ switch (parent->op)
+ {
+ case kIROp_Block:
+ prevBlock = (IRBlock*)parent;
+ parentFunc = prevBlock->getParent();
+ break;
+
+ default:
+ SLANG_UNEXPECTED("bad parent kind for block");
+ return;
+ }
+
+ // If the previous block doesn't already have
+ // a terminator instruction, then be sure to
+ // emit a branch to the new block.
+ if (!isTerminatorInst(prevBlock->lastChild))
+ {
+ builder->emitBranch(block);
+ }
+
+ builder->parentInst = parentFunc;
+ builder->addInst(block);
+ builder->parentInst = block;
+ }
+
+ // Start a new block at the current location.
+ // This is just the composition of `createBlock`
+ // and `insertBlock`.
+ IRBlock* startBlock()
+ {
+ auto block = createBlock();
+ insertBlock(block);
+ return block;
+ }
+
+ void visitIfStmt(IfStmt* stmt)
+ {
+ auto builder = getBuilder();
+
+ auto condExpr = stmt->Predicate;
+ auto thenStmt = stmt->PositiveStatement;
+ auto elseStmt = stmt->NegativeStatement;
+
+ auto irCond = getSimpleVal(context,
+ lowerRValueExpr(context, condExpr));
+
+ if (elseStmt)
+ {
+ auto thenBlock = createBlock();
+ auto elseBlock = createBlock();
+ auto afterBlock = createBlock();
+
+ builder->emitIfElse(irCond, thenBlock, elseBlock, afterBlock);
+
+ insertBlock(thenBlock);
+ lowerStmt(context, thenStmt);
+ builder->emitBranch(afterBlock);
+
+ insertBlock(elseBlock);
+ lowerStmt(context, elseStmt);
+
+ insertBlock(afterBlock);
+ }
+ else
+ {
+ auto thenBlock = createBlock();
+ auto afterBlock = createBlock();
+
+ builder->emitIf(irCond, thenBlock, afterBlock);
+
+ insertBlock(thenBlock);
+ lowerStmt(context, thenStmt);
+
+ insertBlock(afterBlock);
+ }
+ }
+
+ void addLoopDecorations(
+ IRInst* inst,
+ Stmt* stmt)
+ {
+ for(auto attr : stmt->GetModifiersOfType<HLSLUncheckedAttribute>())
+ {
+ // TODO: We should actually catch these attributes during
+ // semantic checking, so that they have a strongly-typed
+ // representation in the AST.
+ if(getText(attr->getName()) == "unroll")
+ {
+ auto decoration = getBuilder()->addDecoration<IRLoopControlDecoration>(inst);
+ decoration->mode = kIRLoopControl_Unroll;
+ }
+ }
+ }
+
+ void visitForStmt(ForStmt* stmt)
+ {
+ auto builder = getBuilder();
+
+ // The initializer clause for the statement
+ // can always safetly be emitted to the current block.
+ if (auto initStmt = stmt->InitialStatement)
+ {
+ lowerStmt(context, initStmt);
+ }
+
+ // We will create blocks for the various places
+ // we need to jump to inside the control flow,
+ // including the blocks that will be referenced
+ // by `continue` or `break` statements.
+ auto loopHead = createBlock();
+ auto bodyLabel = createBlock();
+ auto breakLabel = createBlock();
+ auto continueLabel = createBlock();
+
+ // TODO: register `loopHead` as the target for a
+ // `continue` statement.
+
+ // Emit the branch that will start out loop,
+ // and then insert the block for the head.
+
+ auto loopInst = builder->emitLoop(
+ loopHead,
+ breakLabel,
+ continueLabel);
+
+ addLoopDecorations(loopInst, stmt);
+
+ insertBlock(loopHead);
+
+ // Now that we are within the header block, we
+ // want to emit the expression for the loop condition:
+ if (auto condExpr = stmt->PredicateExpression)
+ {
+ auto irCondition = getSimpleVal(context,
+ lowerRValueExpr(context, stmt->PredicateExpression));
+
+ // Now we want to `break` if the loop condition is false.
+ builder->emitLoopTest(
+ irCondition,
+ bodyLabel,
+ breakLabel);
+ }
+
+ // Emit the body of the loop
+ insertBlock(bodyLabel);
+ lowerStmt(context, stmt->Statement);
+
+ // Insert the `continue` block
+ insertBlock(continueLabel);
+ if (auto incrExpr = stmt->SideEffectExpression)
+ {
+ lowerRValueExpr(context, incrExpr);
+ }
+
+ // At the end of the body we need to jump back to the top.
+ builder->emitBranch(loopHead);
+
+ // Finally we insert the label that a `break` will jump to
+ insertBlock(breakLabel);
+ }
+
void visitExpressionStmt(ExpressionStmt* stmt)
{
// The statement evaluates an expression
@@ -960,7 +1611,11 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// lower the expression, and don't use
// the result.
//
- lowerExpr(context, stmt->Expression);
+ // Note that we lower using the l-value path,
+ // so that an expression statement that names
+ // a location (but doesn't load from it)
+ // will not actually emit a load.
+ lowerLValueExpr(context, stmt->Expression);
}
void visitDeclStmt(DeclStmt* stmt)
@@ -1004,7 +1659,7 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
// a value first, and then emit the resulting value.
if( auto expr = stmt->Expression )
{
- auto loweredExpr = lowerExpr(context, expr);
+ auto loweredExpr = lowerRValueExpr(context, expr);
getBuilder()->emitReturn(getSimpleVal(context, loweredExpr));
}
@@ -1026,9 +1681,15 @@ void lowerStmt(
void assign(
IRGenContext* context,
- LoweredValInfo const& left,
- LoweredValInfo const& right)
+ LoweredValInfo const& inLeft,
+ LoweredValInfo const& inRight)
{
+ LoweredValInfo left = inLeft;
+ LoweredValInfo right = inRight;
+
+ auto builder = context->irBuilder;
+
+top:
switch (left.flavor)
{
case LoweredValInfo::Flavor::Ptr:
@@ -1036,8 +1697,8 @@ void assign(
{
case LoweredValInfo::Flavor::Simple:
case LoweredValInfo::Flavor::Ptr:
+ case LoweredValInfo::Flavor::SwizzledLValue:
{
- auto builder = context->irBuilder;
builder->emitStore(
left.val,
getSimpleVal(context, right));
@@ -1050,6 +1711,90 @@ void assign(
}
break;
+ case LoweredValInfo::Flavor::SwizzledLValue:
+ {
+ // The `left` value is of the form `<someLValue>.<swizzleElements>`.
+ //
+ // We could conceivably define a custom "swizzled store" instruction
+ // that would handle the common case where the base l-value is
+ // a simple lvalue (`LowerdValInfo::Flavor::Ptr`):
+ //
+ // float4 foo;
+ // foo.zxy = float3(...);
+ //
+ // However, this doesn't handle complex cases like the following:
+ //
+ // RWStructureBuffer<float4> foo;
+ // ...
+ // foo[index].xzy = float3(...);
+ //
+ // In a case like that, we really need to lower through a temp:
+ //
+ // float4 tmp = foo[index];
+ // tmp.xzy = float3(...);
+ // foo[index] = tmp;
+ //
+ // We want to handle the general case, we we might as well
+ // try to handle everything uniformly.
+ //
+ auto swizzleInfo = left.getSwizzledLValueInfo();
+ auto type = swizzleInfo->type;
+ auto loweredBase = swizzleInfo->base;
+
+ // Load from the base value:
+ IRInst* irLeftVal = getSimpleVal(context, loweredBase);
+ auto irRightVal = getSimpleVal(context, right);
+
+ // Now apply the swizzle
+ IRInst* irSwizzled = builder->emitSwizzleSet(
+ irLeftVal->getType(),
+ irLeftVal,
+ irRightVal,
+ swizzleInfo->elementCount,
+ swizzleInfo->elementIndices);
+
+ // And finally, store the value back where we got it.
+ //
+ // Note: this is effectively a recursive call to
+ // `assign()`, so we do a simple tail-recursive call here.
+ left = loweredBase;
+ right = LoweredValInfo::simple(irSwizzled);
+ goto top;
+ }
+ break;
+
+ case LoweredValInfo::Flavor::BoundSubscript:
+ {
+ // The `left` value refers to a subscript operation on
+ // a resource type, bound to particular arguments, e.g.:
+ // `someStructuredBuffer[index]`.
+ //
+ // When storing to such a value, we need to emit a call
+ // to the appropriate builtin "setter" accessor.
+ auto subscriptInfo = left.getBoundSubscriptInfo();
+ auto type = subscriptInfo->type;
+
+ // Search for an appropriate "setter" declaration
+ for (auto setterDeclRef : getMembersOfType<SetterDecl>(subscriptInfo->declRef))
+ {
+ auto allArgs = subscriptInfo->args;
+
+ addArgs(context, &allArgs, right);
+
+ emitCallToDeclRef(
+ context,
+ builder->getVoidType(),
+ setterDeclRef,
+ allArgs);
+ return;
+ }
+
+ // No setter found? Then we have an error!
+ SLANG_UNEXPECTED("no setter found");
+ break;
+ }
+ break;
+
default:
SLANG_UNIMPLEMENTED_X("assignment");
break;
@@ -1120,33 +1865,44 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto varType = lowerType(context, decl->getType());
- LoweredValInfo varVal;
+ // TODO: If the variable is marked `static` then we need to
+ // deal with it specially: we should move its allocation out
+ // to the global scope, and then we have to deal with its
+ // initializer expression a bit carefully (it should only
+ // be initialized on-demand at its first use).
+
+ // Some qualifiers on a variable will change how we allocate it,
+ // so we need to reflect that somehow. The first example
+ // we run into is the `groupshared` qualifier, which marks
+ // a variable in a compute shader as having per-group allocation
+ // rather than the traditional per-thread (or rather per-thread
+ // per-activation-record) allocation.
+ //
+ // Options include:
+ //
+ // - Use a distinct allocation opration, so that the type
+ // of the variable address/value is unchanged.
+ //
+ // - Add a notion of an "address space" to pointer types,
+ // so that we can allocate things in distinct spaces.
+ //
+ // - Add a notion of a "rate" so that we can declare a
+ // variable with a distinct rate.
+ //
+ // For now we might do the expedient thing and handle this
+ // via a notion of an "address space."
- switch( varType.flavor )
+ IRAddressSpace addressSpace = kIRAddressSpace_Default;
+ if (decl->HasModifier<HLSLGroupSharedModifier>())
{
- case LoweredTypeInfo::Flavor::Simple:
- {
- auto irAlloc = getBuilder()->emitVar(getSimpleType(varType));
-
- getBuilder()->addHighLevelDeclDecoration(irAlloc, decl);
-
- if (getLayout())
- {
- getBuilder()->addLayoutDecoration(irAlloc, getLayout());
- }
-
-
- varVal = LoweredValInfo::ptr(irAlloc);
- }
- break;
-
- default:
- SLANG_UNIMPLEMENTED_X("struct field type");
+ addressSpace = kIRAddressSpace_GroupShared;
}
+ LoweredValInfo varVal = createVar(context, varType, decl, getLayout(), addressSpace);
+
if( auto initExpr = decl->initExpr )
{
- auto initVal = lowerExpr(context, initExpr);
+ auto initVal = lowerRValueExpr(context, initExpr);
assign(context, varVal, initVal);
}
@@ -1241,6 +1997,8 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
IRParam* irParam = subBuilder->emitParam(irParamType);
+ subBuilder->addHighLevelDeclDecoration(irParam, paramDecl);
+
DeclRef<ParamDecl> paramDeclRef = makeDeclRef(paramDecl.Ptr());
LoweredValInfo irParamVal = LoweredValInfo::simple(irParam);
@@ -1278,6 +2036,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
lowerStmt(subContext, decl->Body);
+ // We need to carefully add a terminator instruction to the end
+ // of the body, in case the user didn't do so.
+ if (!isTerminatorInst(subContext->irBuilder->parentInst->lastChild))
+ {
+ if (irResultType->op == kIROp_VoidType)
+ {
+ subContext->irBuilder->emitReturn();
+ }
+ else
+ {
+ SLANG_UNEXPECTED("Needed a return here");
+ subContext->irBuilder->emitReturn();
+ }
+ }
+
getBuilder()->addHighLevelDeclDecoration(irFunc, decl);
getBuilder()->addInst(irFunc);
@@ -1365,11 +2138,48 @@ static void lowerEntryPointToIR(
EntryPointRequest* entryPointRequest,
EntryPointLayout* entryPointLayout)
{
- auto entryPointFunc = entryPointLayout->entryPoint;
+ // First, lower the entry point like an ordinary function
+ auto entryPointFuncDecl = entryPointLayout->entryPoint;
+ auto loweredEntryPointFunc = lowerDecl(context, entryPointFuncDecl, entryPointLayout);
+ auto irFunc = getSimpleVal(context, loweredEntryPointFunc);
+
+ auto builder = context->irBuilder;
+
+ // We are going to attach all the entry-point-specific information
+ // to the declaration as meta-data decorations for now.
+ //
+ // I'm not convinced this is the right way to go, but it is
+ // the easiest and most expedient thing.
+ //
+ auto profile = entryPointRequest->profile;
+ auto stage = profile.GetStage();
- // TODO: entry point lowering is probably *not* just like lowering a function...
+ auto entryPointDecoration = builder->addDecoration<IREntryPointDecoration>(irFunc);
+ entryPointDecoration->profile = profile;
- lowerDecl(context, entryPointFunc, entryPointLayout);
+ // Next, we need to start attaching the meta-data that is
+ // required based on the particular stage we are targetting:
+ switch (stage)
+ {
+ case Stage::Compute:
+ {
+ // We need to attach information about the thread group size here.
+ auto threadGroupSizeDecoration = builder->addDecoration<IRComputeThreadGroupSizeDecoration>(irFunc);
+ static const UInt kAxisCount = 3;
+
+ // TODO: this is kind of gross because we are using a public
+ // reflection API function, rather than some kind of internal
+ // utility it forwards to...
+ spReflectionEntryPoint_getComputeThreadGroupSize(
+ (SlangReflectionEntryPoint*)entryPointLayout,
+ kAxisCount,
+ &threadGroupSizeDecoration->sizeAlongAxis[0]);
+ }
+ break;
+
+ default:
+ break;
+ }
}
IRModule* lowerEntryPointToIR(
@@ -1412,4 +2222,17 @@ IRModule* lowerEntryPointToIR(
}
+String emitSlangIRAssemblyForEntryPoint(
+ EntryPointRequest* entryPoint)
+{
+ auto compileRequest = entryPoint->compileRequest;
+ auto irModule = lowerEntryPointToIR(
+ entryPoint,
+ compileRequest->layout.Ptr(),
+ // TODO: we need to pick the target more carefully here
+ CodeGenTarget::HLSL);
+
+ return getSlangIRAssembly(irModule);
+}
+
}
diff --git a/source/slang/options.cpp b/source/slang/options.cpp
index 669ab1827..a87692700 100644
--- a/source/slang/options.cpp
+++ b/source/slang/options.cpp
@@ -300,6 +300,8 @@ struct OptionsParser
CASE(spirv, SPIRV);
CASE(spirv-assembly, SPIRV_ASM);
+ CASE(slang-ir, IR);
+ CASE(slang-ir-assembly, IR_ASM);
CASE(none, TARGET_NONE);
#undef CASE
diff --git a/source/slang/profile-defs.h b/source/slang/profile-defs.h
index 84153ee46..513ba3078 100644
--- a/source/slang/profile-defs.h
+++ b/source/slang/profile-defs.h
@@ -47,6 +47,8 @@ LANGUAGE(GLSL_ES, glsl_es)
LANGUAGE(GLSL_VK, glsl_vk)
LANGUAGE(SPIRV, spirv)
LANGUAGE(SPIRV_GL, spirv_gl)
+LANGUAGE(SlangIR, slang_ir)
+LANGUAGE(SlangIRAssembly, slang_ir_assembly)
LANGUAGE_ALIAS(GLSL, glsl_gl)
LANGUAGE_ALIAS(SPIRV, spirv_vk)
diff --git a/source/slang/slang.vcxproj b/source/slang/slang.vcxproj
index 1f55138e4..08a448a6d 100644
--- a/source/slang/slang.vcxproj
+++ b/source/slang/slang.vcxproj
@@ -172,6 +172,7 @@
<ClInclude Include="emit.h" />
<ClInclude Include="expr-defs.h" />
<ClInclude Include="ir-inst-defs.h" />
+ <ClInclude Include="ir-insts.h" />
<ClInclude Include="ir.h" />
<ClInclude Include="lexer.h" />
<ClInclude Include="lookup.h" />
diff --git a/source/slang/slang.vcxproj.filters b/source/slang/slang.vcxproj.filters
index 9a85ce966..111b23d36 100644
--- a/source/slang/slang.vcxproj.filters
+++ b/source/slang/slang.vcxproj.filters
@@ -39,6 +39,7 @@
<ClInclude Include="ir.h" />
<ClInclude Include="lower-to-ir.h" />
<ClInclude Include="ir-inst-defs.h" />
+ <ClInclude Include="ir-insts.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="check.cpp" />
@@ -65,8 +66,8 @@
<ClCompile Include="lower-to-ir.cpp" />
</ItemGroup>
<ItemGroup>
- <None Include="core.meta.slang" />
- <None Include="glsl.meta.slang" />
- <None Include="hlsl.meta.slang" />
+ <CustomBuild Include="core.meta.slang" />
+ <CustomBuild Include="glsl.meta.slang" />
+ <CustomBuild Include="hlsl.meta.slang" />
</ItemGroup>
</Project> \ No newline at end of file
diff --git a/source/slang/type-layout.cpp b/source/slang/type-layout.cpp
index 00a625427..ec71df37d 100644
--- a/source/slang/type-layout.cpp
+++ b/source/slang/type-layout.cpp
@@ -621,6 +621,8 @@ LayoutRulesFamilyImpl* GetLayoutRulesFamilyImpl(CodeGenTarget target)
case CodeGenTarget::HLSL:
case CodeGenTarget::DXBytecode:
case CodeGenTarget::DXBytecodeAssembly:
+ case CodeGenTarget::SlangIR:
+ case CodeGenTarget::SlangIRAssembly:
return &kHLSLLayoutRulesFamilyImpl;
case CodeGenTarget::GLSL:
diff --git a/tests/ir/loop.slang b/tests/ir/loop.slang
new file mode 100644
index 000000000..d637e4536
--- /dev/null
+++ b/tests/ir/loop.slang
@@ -0,0 +1,33 @@
+//TEST:SIMPLE:-target slang-ir-assembly -profile cs_4_0 -entry main
+
+#define GROUP_THREAD_COUNT 64
+
+StructuredBuffer<float4> input;
+RWStructuredBuffer<float4> output;
+
+groupshared float4 s[GROUP_THREAD_COUNT];
+
+[numthreads(GROUP_THREAD_COUNT, 1, 1)]
+void main(
+ uint dispatchThreadID : SV_DispatchThreadIndex,
+ uint groupThreadID : SV_GroupThreadIndex,
+ uint groupID : SV_GroupIndex )
+{
+ // the actual algorithm being done here is bogus
+
+ // load shared memory
+ s[groupThreadID] = input[dispatchThreadID];
+
+ // do some sum passes
+ for(uint stride = 1; stride < GROUP_THREAD_COUNT; stride <<= 1)
+ {
+ GroupMemoryBarrierWithGroupSync();
+
+ s[groupThreadID] += s[groupThreadID - stride];
+ }
+
+ GroupMemoryBarrierWithGroupSync();
+
+ output[dispatchThreadID] = s[0];
+}
+
diff --git a/tests/ir/loop.slang.expected b/tests/ir/loop.slang.expected
new file mode 100644
index 000000000..2dc091b0e
--- /dev/null
+++ b/tests/ir/loop.slang.expected
@@ -0,0 +1,67 @@
+result code = 0
+standard error = {
+}
+standard output = {
+let %41 : Ptr<Array<Vec<Float32,4>,64>,1> = var()
+let %68 : Ptr<StructuredBuffer<Vec<Float32,4>>,0> = var()
+let %243 : Ptr<RWStructuredBuffer<Vec<Float32,4>>,0> = var()
+
+func %1(
+ param %7 : UInt32,
+ param %10 : UInt32,
+ param %13 : UInt32)
+{
+block %4:
+ let %47 : Ptr<Vec<Float32,4>,0> = getElementPtr(%41, %10)
+ let %69 : StructuredBuffer<Vec<Float32,4>> = load(%68)
+ let %72 : Vec<Float32,4> = bufferLoad(%69, %7)
+ store(%47, %72)
+ let %81 : Ptr<UInt32,0> = var()
+ let %89 : UInt32 = construct(1)
+ store(%81, %89)
+ loop(%94, %100, %103)
+
+block %94:
+ let %110 : UInt32 = load(%81)
+ let %119 : UInt32 = construct(64)
+ let %120 : Bool = cmpLT(%110, %119)
+ loopTest(%120, %97, %100)
+
+block %97:
+ GroupMemoryBarrierWithGroupSync()
+ let %147 : Ptr<Vec<Float32,4>,0> = getElementPtr(%41, %10)
+ let %152 : Ptr<Vec<Float32,4>,0> = var()
+ let %153 : Vec<Float32,4> = load(%147)
+ store(%152, %153)
+ let %174 : UInt32 = load(%81)
+ let %175 : UInt32 = sub(%10, %174)
+ let %180 : Ptr<Vec<Float32,4>,0> = getElementPtr(%41, %175)
+ let %181 : Vec<Float32,4> = load(%180)
+ let %182 : Vec<Float32,4> = load(%152)
+ let %183 : Vec<Float32,4> = add(%182, %181)
+ store(%152, %183)
+ let %186 : Vec<Float32,4> = load(%152)
+ store(%147, %186)
+ unconditionalBranch(%103)
+
+block %103:
+ let %199 : Ptr<UInt32,0> = var()
+ let %200 : UInt32 = load(%81)
+ store(%199, %200)
+ let %211 : UInt32 = construct(1)
+ let %212 : UInt32 = load(%199)
+ let %213 : UInt32 = shl(%212, %211)
+ store(%199, %213)
+ let %216 : UInt32 = load(%199)
+ store(%81, %216)
+ unconditionalBranch(%94)
+
+block %100:
+ GroupMemoryBarrierWithGroupSync()
+ let %244 : RWStructuredBuffer<Vec<Float32,4>> = load(%243)
+ let %260 : Ptr<Vec<Float32,4>,0> = getElementPtr(%41, 0)
+ let %261 : Vec<Float32,4> = load(%260)
+ bufferStore(%244, %7, %261)
+ return_void()
+}
+}