summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorTim Foley <tfoleyNV@users.noreply.github.com>2017-10-04 13:54:25 -0700
committerGitHub <noreply@github.com>2017-10-04 13:54:25 -0700
commit54f016e7ef36b7505bf47d188cf4b7e1fdc443a4 (patch)
treef8a385c8a3bbac807c2c0d08a9b1e4cd208db95c
parent8a0ebb9fa25fd44def17b03b3f8aa1a33ad77940 (diff)
IR: overhaul IR design/implementation (#195)
* IR: overhaul IR design/implementation Closes #192 Closes #188 This is a major overhaul of how the IR is implemented, with the primary goal of just using the AST-level type representation as the IR's type representation, rather than inventing an entire shadow set of types (as captured in issue #192). One consequence of this choice is that types in the IR are no longer explicit "instructions" and are not represented as ordinary operands (so a bunch of `+ 1` cases end up going away when enumerating ordinary operands). Along the way I also got rid of the embedded IDs in the IR (issue #188) because this wasn't too hard to deal with at the same time. Another related change was to split the `IRValue` and `IRInst` cases, so that there are values that are not also instructions. Non-instruction values are now used to represent literals, references to declarations, and would eventually be used for an `undef` value if we need one. IR functions, global variables, and basic blocks are all values (because they can appear as operands), but not instructions. The main benefit of this approach is that the top-level structure of a bytecode (BC) module is much simpler to understand and walk, and BC-level types are represented much more directly (such that we could conceivably use them for reflection soon). * fixup: 64-bit build fix * fixup: try to silence clang's pedantic dependent-type errors * fixup: bug in VM loading of constants
-rw-r--r--source/slang/bytecode.cpp555
-rw-r--r--source/slang/bytecode.h102
-rw-r--r--source/slang/check.cpp27
-rw-r--r--source/slang/compiler.h8
-rw-r--r--source/slang/core.meta.slang11
-rw-r--r--source/slang/core.meta.slang.h12
-rw-r--r--source/slang/emit.cpp487
-rw-r--r--source/slang/ir-inst-defs.h29
-rw-r--r--source/slang/ir-insts.h236
-rw-r--r--source/slang/ir.cpp1065
-rw-r--r--source/slang/ir.h261
-rw-r--r--source/slang/lower-to-ir.cpp320
-rw-r--r--source/slang/lower.cpp19
-rw-r--r--source/slang/syntax.cpp132
-rw-r--r--source/slang/syntax.h23
-rw-r--r--source/slang/type-defs.h63
-rw-r--r--source/slang/vm.cpp616
-rw-r--r--tests/ir/loop.slang.expected129
-rw-r--r--tools/eval-test/main.cpp6
19 files changed, 2368 insertions, 1733 deletions
diff --git a/source/slang/bytecode.cpp b/source/slang/bytecode.cpp
index e412a5b94..854d9bf34 100644
--- a/source/slang/bytecode.cpp
+++ b/source/slang/bytecode.cpp
@@ -19,8 +19,8 @@ struct SharedBytecodeGenerationContext;
template<typename T>
struct BytecodeGenerationPtr
{
- SharedBytecodeGenerationContext* sharedContext;
UInt offset;
+ SharedBytecodeGenerationContext* sharedContext;
BytecodeGenerationPtr()
: sharedContext(nullptr)
@@ -49,31 +49,40 @@ struct BytecodeGenerationPtr
, offset(ptr.offset)
{}
- operator BCPtr<T>()
+ template<typename U>
+ BytecodeGenerationPtr<U> bitCast() const
+ {
+ return BytecodeGenerationPtr<U>(sharedContext, offset);
+ }
+
+ operator BCPtr<T>() const
{
return BCPtr<T>(getPtr());
}
- T* operator->()
+ T* operator->() const
{
return getPtr();
}
- T& operator*()
+ T& operator*() const
{
return *getPtr();
}
- T& operator[](UInt index)
+ T& operator[](UInt index) const
{
return getPtr()[index];
}
- BytecodeGenerationPtr<T> operator+(Int index)
+ BytecodeGenerationPtr<T> operator+(Int index) const
{
+ UInt size = sizeof(T);
+ Int delta = index * sizeof(T);
+ UInt newOffset = offset + delta;
return BytecodeGenerationPtr<T>(
sharedContext,
- offset + index*sizeof(T));
+ newOffset);
}
T* getPtr() const;
@@ -93,14 +102,27 @@ struct SharedBytecodeGenerationContext
// The final generated bytecode stream
List<uint8_t> bytecode;
- // Map from a global symbol to its global ID
- Dictionary<IRInst*, Int> mapGlobalSymbolToGLobalID;
+ // Map from an IR value to a global entity
+ // that encodes it:
+ Dictionary<IRValue*, BCConst> mapValueToGlobal;
+
+ // Types that have been emitted
+ List<BytecodeGenerationPtr<BCType>> bcTypes;
+ Dictionary<Type*, UInt> mapTypeToID;
+
+ // Compile-time constant values that need
+ // to be emitted...
+ List<IRValue*> constants;
};
struct BytecodeGenerationContext
{
SharedBytecodeGenerationContext* shared;
+ // The bytecode of the current symbol being
+ // output.
+ List<uint8_t> currentBytecode;
+
// The function that is in scope for this context
IRFunc* currentIRFunc;
@@ -110,11 +132,7 @@ struct BytecodeGenerationContext
// Map an instruction to its ID for use local
// to the current context
- Dictionary<IRInst*, Int> mapInstToLocalID;
-
- // Map an instruction to the ID for its auxiliary
- // symbol data
- Dictionary<IRInst*, UInt> mapInstToNestedID;
+ Dictionary<IRValue*, Int> mapInstToLocalID;
};
template<typename T>
@@ -169,7 +187,7 @@ void encodeUInt8(
BytecodeGenerationContext* context,
uint8_t value)
{
- context->shared->bytecode.Add(value);
+ context->currentBytecode.Add(value);
}
void encodeUInt(
@@ -220,50 +238,98 @@ void encodeSInt(
encodeUInt(context, uValue);
}
-Int getLocalID(
+BCConst getGlobalValue(
BytecodeGenerationContext* context,
- IRInst* inst)
+ IRValue* value)
{
- Int localID = 0;
- if( context->mapInstToLocalID.TryGetValue(inst, localID) )
- {
- return localID;
- }
+ BCConst bcConst;
+ if( context->shared->mapValueToGlobal.TryGetValue(value, bcConst) )
+ return bcConst;
+
+ // Next we need to check for things that can be mapped to
+ // global IDs on the fly.
- Int globalID = 0;
- if( context->shared->mapGlobalSymbolToGLobalID.TryGetValue(inst, globalID) )
+ switch( value->op )
{
- BCConst bcConst;
- bcConst.globalID = globalID;
+ case kIROp_IntLit:
+ {
+ UInt constID = context->shared->constants.Count();
+ context->shared->constants.Add(value);
- UInt remappedSymbolIndex = context->remappedGlobalSymbols.Count();
- context->remappedGlobalSymbols.Add(bcConst);
+ BCConst bcConst;
+ bcConst.flavor = kBCConstFlavor_Constant;
+ bcConst.id = constID;
- localID = ~remappedSymbolIndex;
- context->mapInstToLocalID.Add(inst, localID);
- return localID;
+ context->shared->mapValueToGlobal.Add(value, bcConst);
+
+ return bcConst;
+ }
+ break;
+
+ default:
+ break;
}
SLANG_UNEXPECTED("no ID for inst");
- return -9999;
+ bcConst.flavor = (BCConstFlavor) -1;
+ bcConst.id = -9999;
+ return bcConst;
+}
+
+Int getLocalID(
+ BytecodeGenerationContext* context,
+ IRValue* value)
+{
+ Int localID = 0;
+ if( context->mapInstToLocalID.TryGetValue(value, localID) )
+ {
+ return localID;
+ }
+
+ BCConst bcConst = getGlobalValue(context, value);
+ UInt remappedSymbolIndex = context->remappedGlobalSymbols.Count();
+ context->remappedGlobalSymbols.Add(bcConst);
+
+ localID = ~remappedSymbolIndex;
+ context->mapInstToLocalID.Add(value, localID);
+ return localID;
}
void encodeOperand(
BytecodeGenerationContext* context,
- IRInst* operand)
+ IRValue* operand)
{
auto id = getLocalID(context, operand);
encodeSInt(context, id);
}
-bool opHasResult(IRInst* inst)
+uint32_t getTypeID(
+ BytecodeGenerationContext* context,
+ Type* type);
+
+void encodeOperand(
+ BytecodeGenerationContext* context,
+ IRType* type)
+{
+ encodeUInt(context, getTypeID(context, type));
+}
+
+bool opHasResult(IRValue* inst)
{
auto type = inst->getType();
- if( !type || type->op != kIROp_VoidType )
+ if (!type) return false;
+
+ // As a bit of a hack right now, we need to check whether
+ // the function returns the distinguished `Void` type,
+ // since that is conceptually the same as "not returning
+ // a value."
+ if (auto basicType = dynamic_cast<BasicExpressionType*>(type))
{
- return true;
+ if (basicType->baseType == BaseType::Void)
+ return false;
}
- return false;
+
+ return true;
}
void generateBytecodeForInst(
@@ -282,15 +348,16 @@ void generateBytecodeForInst(
//
auto argCount = inst->getArgCount();
+ auto type = inst->getType();
encodeUInt(context, inst->op);
+ encodeOperand(context, inst->getType());
encodeUInt(context, argCount);
for( UInt aa = 0; aa < argCount; ++aa )
{
encodeOperand(context, inst->getArg(aa));
}
- auto type = inst->getType();
- if( type && type->op == kIROp_VoidType )
+ if (!opHasResult(inst))
{
// This instructions has no type, so don't emit a destination
}
@@ -354,6 +421,7 @@ void generateBytecodeForInst(
}
break;
+#if 0
case kIROp_Func:
{
encodeUInt(context, inst->op);
@@ -368,6 +436,7 @@ void generateBytecodeForInst(
encodeOperand(context, inst);
}
break;
+#endif
case kIROp_Store:
{
@@ -375,9 +444,9 @@ void generateBytecodeForInst(
// We need to encode the type being stored, to make
// our lives easier.
- encodeOperand(context, inst->getArg(2)->getType());
+ encodeOperand(context, inst->getArg(1)->getType());
+ encodeOperand(context, inst->getArg(0));
encodeOperand(context, inst->getArg(1));
- encodeOperand(context, inst->getArg(2));
}
break;
@@ -385,33 +454,166 @@ void generateBytecodeForInst(
{
encodeUInt(context, inst->op);
encodeOperand(context, inst->getType());
- encodeOperand(context, inst->getArg(1));
+ encodeOperand(context, inst->getArg(0));
encodeOperand(context, inst);
}
break;
}
}
-Int getIDForGlobalSymbol(
+BytecodeGenerationPtr<BCType> emitBCType(
+ BytecodeGenerationContext* context,
+ Type* type,
+ IROp op,
+ BytecodeGenerationPtr<uint8_t> const* args,
+ UInt argCount)
+{
+ UInt size = sizeof(BCType)
+ + argCount * sizeof(BCPtr<void>);
+
+ BytecodeGenerationPtr<uint8_t> bcAllocation(
+ context->shared,
+ allocateRaw(context, size, alignof(BCPtr<void>)));
+
+ BytecodeGenerationPtr<BCType> bcType = bcAllocation.bitCast<BCType>();
+ auto bcArgs = (bcType + 1).bitCast<BCPtr<uint8_t>>();
+
+ bcType->op = op;
+ bcType->argCount = argCount;
+
+ for(UInt aa = 0; aa < argCount; ++aa)
+ {
+ bcArgs[aa] = args[aa];
+ }
+
+ UInt id = context->shared->bcTypes.Count();
+ context->shared->mapTypeToID.Add(type, id);
+ context->shared->bcTypes.Add(bcType);
+ bcType->id = id;
+
+ return bcType;
+}
+
+BytecodeGenerationPtr<BCType> emitBCVarArgType(
+ BytecodeGenerationContext* context,
+ Type* type,
+ IROp op,
+ List<BytecodeGenerationPtr<uint8_t>> args)
+{
+ return emitBCType(context, type, op, args.Buffer(), args.Count());
+}
+
+BytecodeGenerationPtr<BCType> emitBCType(
BytecodeGenerationContext* context,
- IRInst* inst)
+ Type* type,
+ IROp op)
+{
+ return emitBCType(context, type, op, nullptr, 0);
+}
+
+BytecodeGenerationPtr<BCType> emitBCType(
+ BytecodeGenerationContext* context,
+ Type* type);
+
+// Emit a `BCType` representation for the given `Type`
+BytecodeGenerationPtr<BCType> emitBCTypeImpl(
+ BytecodeGenerationContext* context,
+ Type* type)
+{
+ // A NULL type is interpreted as equivalent to `Void` for now.
+ if( !type )
+ {
+ return emitBCType(context, type, kIROp_VoidType);
+ }
+
+ if( auto basicType = type->As<BasicExpressionType>() )
+ {
+ switch(basicType->baseType)
+ {
+ case BaseType::Void: return emitBCType(context, type, kIROp_VoidType);
+ case BaseType::Bool: return emitBCType(context, type, kIROp_BoolType);
+ case BaseType::Int: return emitBCType(context, type, kIROp_Int32Type);
+ case BaseType::UInt: return emitBCType(context, type, kIROp_UInt32Type);
+ case BaseType::UInt64: return emitBCType(context, type, kIROp_UInt64Type);
+ case BaseType::Half: return emitBCType(context, type, kIROp_Float16Type);
+ case BaseType::Float: return emitBCType(context, type, kIROp_Float32Type);
+ case BaseType::Double: return emitBCType(context, type, kIROp_Float64Type);
+
+ default:
+ break;
+ }
+ }
+ else if( auto funcType = type->As<FuncType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+
+ operands.Add(emitBCType(context, funcType->resultType).bitCast<uint8_t>());
+ UInt paramCount = funcType->getParamCount();
+ for(UInt pp = 0; pp < paramCount; ++pp)
+ {
+ operands.Add(emitBCType(context, funcType->getParamType(pp)).bitCast<uint8_t>());
+ }
+
+ return emitBCVarArgType(context, type, kIROp_FuncType, operands);
+ }
+ else if( auto ptrType = type->As<PtrType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+ operands.Add(emitBCType(context, ptrType->getValueType()).bitCast<uint8_t>());
+ return emitBCVarArgType(context, type, kIROp_PtrType, operands);
+ }
+ else if( auto rwStructuredBufferType = type->As<HLSLRWStructuredBufferType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+ operands.Add(emitBCType(context, rwStructuredBufferType->elementType).bitCast<uint8_t>());
+ return emitBCVarArgType(context, type, kIROp_readWriteStructuredBufferType, operands);
+ }
+ else if( auto structuredBufferType = type->As<HLSLStructuredBufferType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+ operands.Add(emitBCType(context, structuredBufferType->elementType).bitCast<uint8_t>());
+ return emitBCVarArgType(context, type, kIROp_structuredBufferType, operands);
+ }
+
+
+ SLANG_UNEXPECTED("unimplemented");
+ return BytecodeGenerationPtr<BCType>();
+}
+
+BytecodeGenerationPtr<BCType> emitBCType(
+ BytecodeGenerationContext* context,
+ Type* type)
{
- Int globalID;
- if(context->shared->mapGlobalSymbolToGLobalID.TryGetValue(inst, globalID))
- return globalID;
+ auto canonical = type->GetCanonicalType();
+ UInt id = 0;
+ if(context->shared->mapTypeToID.TryGetValue(canonical, id))
+ {
+ return context->shared->bcTypes[id];
+ }
- SLANG_UNEXPECTED("no such ID");
+ BytecodeGenerationPtr<BCType> bcType = emitBCTypeImpl(context, canonical);
+ return bcType;
}
-uint32_t getTypeForGlobalSymbol(
+uint32_t getTypeID(
BytecodeGenerationContext* context,
- IRInst* inst)
+ Type* type)
+{
+ // We have a type, and we need to emit it (if we haven't
+ // already) and return its index in the global type table.
+ BytecodeGenerationPtr<BCType> bcType = emitBCType(context, type);
+ return bcType->id;
+}
+
+uint32_t getTypeIDForGlobalSymbol(
+ BytecodeGenerationContext* context,
+ IRValue* inst)
{
auto type = inst->getType();
if(!type)
return 0;
- return getIDForGlobalSymbol(context, type);
+ return getTypeID(context, type);
}
BytecodeGenerationPtr<char> allocateString(
@@ -442,7 +644,7 @@ BytecodeGenerationPtr<char> allocateString(
BytecodeGenerationPtr<char> tryGenerateNameForSymbol(
BytecodeGenerationContext* context,
- IRInst* inst)
+ IRGlobalValue* inst)
{
// TODO: this is gross, and the IR should probably have
// a more direct means of querying a name for a symbol.
@@ -462,9 +664,10 @@ BytecodeGenerationPtr<char> tryGenerateNameForSymbol(
return BytecodeGenerationPtr<char>();
}
+// Generate a `BCSymbol` that can represent a global value.
BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
BytecodeGenerationContext* context,
- IRInst* inst)
+ IRGlobalValue* inst)
{
switch( inst->op )
{
@@ -474,7 +677,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
BytecodeGenerationPtr<BCFunc> bcFunc = allocate<BCFunc>(context);
bcFunc->op = inst->op;
- bcFunc->typeGlobalID = getTypeForGlobalSymbol(context, inst);
+ bcFunc->typeID = getTypeIDForGlobalSymbol(context, inst);
BytecodeGenerationContext subContextStorage;
BytecodeGenerationContext* subContext = &subContextStorage;
@@ -515,7 +718,17 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
{
UInt blockID = blockCounter++;
UInt paramCount = 0;
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+
+ for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() )
+ {
+ // A parameter always uses a register.
+ regCounter++;
+ //
+ // We also want to keep a count of the parameters themselves.
+ paramCount++;
+ }
+
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
{
switch( ii->op )
{
@@ -528,15 +741,6 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
break;
- case kIROp_Param:
- // A parameter always uses a register.
- regCounter++;
- //
- // We also want to keep a count of the parameters themselves.
- paramCount++;
- break;
-
-
case kIROp_Var:
// A `var` (`alloca`) node needs two registers:
// one to hold the actual storage, and another
@@ -573,30 +777,25 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
// are always the first N registers in the overall list.
//
bcBlocks[blockID].params = bcRegs + regCounter;
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+ for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() )
{
- if(ii->op != kIROp_Param)
- continue;
-
Int localID = regCounter++;
- subContext->mapInstToLocalID.Add(ii, localID);
+ subContext->mapInstToLocalID.Add(pp, localID);
- bcRegs[localID].op = ii->op;
- bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
+ bcRegs[localID].op = pp->op;
+#if 0
+ bcRegs[localID].name = tryGenerateNameForSymbol(context, pp);
+#endif
bcRegs[localID].previousVarIndexPlusOne = localID;
- bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii);
+ bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, pp);
}
// Now loop over the non-parameter instructions and
// allocate actual register locations to them.
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
{
switch(ii->op)
{
- case kIROp_Param:
- // Already handled.
- break;
-
default:
// For an ordinary instruction with a result,
// allocate it here.
@@ -606,9 +805,11 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
subContext->mapInstToLocalID.Add(ii, localID);
bcRegs[localID].op = ii->op;
+#if 0
bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
+#endif
bcRegs[localID].previousVarIndexPlusOne = localID;
- bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii);
+ bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii);
}
break;
@@ -624,30 +825,39 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
subContext->mapInstToLocalID.Add(ii, localID);
bcRegs[localID].op = ii->op;
+#if 0
bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
+#endif
bcRegs[localID].previousVarIndexPlusOne = localID;
- bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii);
+ bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii);
bcRegs[localID+1].op = ii->op;
bcRegs[localID+1].previousVarIndexPlusOne = localID+1;
- bcRegs[localID+1].typeGlobalID = getIDForGlobalSymbol(context,
- ((IRPtrType*) ii->getType())->getValueType());
+ bcRegs[localID+1].typeID = getTypeID(context,
+ (ii->getType()->As<PtrType>())->getValueType());
}
break;
}
}
}
+ assert(regCounter == regCount);
// Now that we've allocated our blocks and our registers
// we can go through the actual process of emitting instructions. Hooray!
blockCounter = 0;
+
+ // Offset of each basic block from the start of the code
+ // for the current funciton.
+ List<UInt> blockOffsets;
for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() )
{
UInt blockID = blockCounter++;
- bcBlocks[blockID].code = getPtr<BCOp>(context);
+ // Get local bytecode offset for current block.
+ UInt blockOffset = subContext->currentBytecode.Count();
+ blockOffsets.Add( blockOffset );
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
{
// What we do with each instruction depends a bit on the
// kind of instruction it is.
@@ -671,6 +881,22 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
}
+
+ // We've collected bytecode for the instruction stream
+ // into a sub-context, so we can now append that code.
+ UInt byteCount = subContext->currentBytecode.Count();
+ BytecodeGenerationPtr<uint8_t> bytes = allocateArray<uint8_t>(context, byteCount);
+ memcpy(&bytes[0], subContext->currentBytecode.Buffer(), byteCount);
+
+ // Now that we've allocated the storage, we can write
+ // the bytecode pointers into the blocks.
+ blockCounter = 0;
+ for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() )
+ {
+ UInt blockID = blockCounter++;
+ bcBlocks[blockID].code = bytes + blockOffsets[blockID];
+ }
+
// Finally, after emitting all the instructions we can
// build a table of global symbols taht need to be
// imported into the current function as constants.
@@ -689,6 +915,17 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
break;
+ case kIROp_global_var:
+ {
+ auto bcVar = allocate<BCSymbol>(context);
+
+ bcVar->op = inst->op;
+ bcVar->typeID = getTypeID(context, inst->type);
+
+ return bcVar;
+ }
+ break;
+
default:
// Most instructions don't need a custom representation.
return BytecodeGenerationPtr<BCSymbol>();
@@ -699,126 +936,98 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
BytecodeGenerationContext* context,
IRModule* irModule)
{
- // The module will get encoded much like a function,
- // and then that function will be "invoked" to load
- // the module.
+ // A module in the bytecode is mostly just a list of the
+ // global symbols in the module.
//
- auto bcModule = allocate<BCModule>(context);
- bcModule->op = irModule->op;
- bcModule->typeGlobalID = 0;
-
- // The logical function that the module representats
- // will only have a single block, so we can allocate it now.
+ // TODO: we need to be careful and recognize the distinction
+ // between the global symbols in the *AST* module, vs. those
+ // symbols which are effectively global in the *IR* module.
//
- auto bcBlock = allocate<BCBlock>(context);
- bcBlock->paramCount = 0;
- bcBlock->params = BytecodeGenerationPtr<BCReg>();
-
- bcModule->blockCount = 1;
- bcModule->blocks = bcBlock;
-
- // Because the module is the top-most level, there is
- // no need for it to have "constants" that represent
- // values imported from the next outer scope.
+ // We probably need to store these distinctly, since we
+ // need the AST global symbols for reflection, and then
+ // also to reconstruct the AST on load when importing a
+ // serialized module. We then need the global IR symbols
+ // to use when linking, to quickly resolve things without
+ // needing any semantic knowledge of nesting at the AST level.
//
- bcModule->constCount = 0;
- bcModule->consts = BytecodeGenerationPtr<BCConst>();
+ auto bcModule = allocate<BCModule>(context);
// We need to compute how many "registers" to allocate
// for the module, where the registers represent the
// values being computed at the global scope.
- UInt regCount = 0;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
+ UInt symbolCount = 0;
+ for( auto gv : irModule->globalValues )
{
- if(!opHasResult(inst))
- continue;
-
- Int globalID = Int(regCount++);
+ Int globalID = Int(symbolCount++);
- context->shared->mapGlobalSymbolToGLobalID.Add(inst, globalID);
+ // Ensure that local code inside functions can see these symbols
+ BCConst bcConst;
+ bcConst.flavor = kBCConstFlavor_GlobalSymbol;
+ bcConst.id = globalID;
+ context->shared->mapValueToGlobal.Add(gv, bcConst);
// In the global scope, global IDs are also the local IDs
- context->mapInstToLocalID.Add(inst, globalID);
+ context->mapInstToLocalID.Add(gv, globalID);
}
- auto bcRegs = allocateArray<BCReg>(context, regCount);
+ auto bcSymbols = allocateArray<BCPtr<BCSymbol>>(context, symbolCount);
- bcModule->regCount = regCount;
- bcModule->regs = bcRegs;
+ bcModule->symbolCount = symbolCount;
+ bcModule->symbols = bcSymbols;
- // Now lets walk through and initialize all those bytecode
- // register representations so that we can use them.
- UInt regCounter= 0;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
+ for( auto gv : irModule->globalValues )
{
- if(!opHasResult(inst))
- continue;
-
- UInt regIndex = *context->mapInstToLocalID.TryGetValue(inst);
+ UInt symbolIndex = *context->mapInstToLocalID.TryGetValue(gv);
- BytecodeGenerationPtr<char> name = tryGenerateNameForSymbol(context, inst);
-
- bcRegs[regIndex].op = inst->op;
- bcRegs[regIndex].name = name;
- bcRegs[regIndex].typeGlobalID = getTypeForGlobalSymbol(context, inst);
- bcRegs[regIndex].previousVarIndexPlusOne = regIndex;
- }
-
- // Some instructions represent "nested" symbols that will need
- // custom handling, and we will represent those here.
- List<BytecodeGenerationPtr<BCSymbol>> nestedSymbols;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
- {
- UInt regIndex = *context->mapInstToLocalID.TryGetValue(inst);
-
- auto bcSymbol = generateBytecodeSymbolForInst(context, inst);
+ auto bcSymbol = generateBytecodeSymbolForInst(context, gv);
if (!bcSymbol.getPtr())
continue;
- UInt nestedSymbolID = nestedSymbols.Count();
- nestedSymbols.Add(bcSymbol);
-
- context->mapInstToNestedID.Add(inst, nestedSymbolID);
+ auto name = tryGenerateNameForSymbol(context, gv);
+ bcSymbol->name = name;
- bcSymbol->name = bcRegs[regIndex].name;
+ bcSymbols[symbolIndex] = bcSymbol;
}
- auto nestedSymbolCount = nestedSymbols.Count();
- auto bcNestedSymbols = allocateArray<BCPtr<BCSymbol>>(context, nestedSymbolCount);
+ // At this point we should have identified all the literals we need:
+ UInt constantCount = context->shared->constants.Count();
+ auto bcConstants = allocateArray<BCConstant>(context, constantCount);
+ bcModule->constantCount = constantCount;
+ bcModule->constants = bcConstants;
- bcModule->nestedSymbolCount = nestedSymbolCount;
- bcModule->nestedSymbols = bcNestedSymbols;
- for (UInt ii = 0; ii < nestedSymbolCount; ++ii)
+ for(UInt cc = 0; cc < constantCount; ++cc)
{
- bcNestedSymbols[ii] = nestedSymbols[ii];
- }
-
+ auto irConstant = (IRConstant*) context->shared->constants[cc];
+ bcConstants[cc].op = irConstant->op;
+ bcConstants[cc].typeID = getTypeID(context, irConstant->type);
- // Finally, we can go through and emit the actual code for
- // the initialization step.
- bcBlock->code = getPtr<BCOp>(context);
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
- {
- // Generate bytecode for global-scope inst
- generateBytecodeForInst(context, inst);
- }
- // Need to encode a terminator here, just to keep the encoding valid
- encodeUInt(context, kIROp_ReturnVoid);
+ switch(irConstant->op)
+ {
+ case kIROp_IntLit:
+ {
+ auto ptr = allocate<IRIntegerValue>(context);
+ *ptr = irConstant->u.intVal;
+ bcConstants[cc].ptr = ptr.bitCast<uint8_t>();
+ }
+ break;
-#if 0
+ default:
+ break;
+ }
- // Now we can go through and generate the bytecode object
- // that will represent each of these global symbols
+ }
- List<BytecodeGenerationPtr<BCSymbol>> globalSymbols;
+ // At this point we should have collected all the types we need:
+ UInt typeCount = context->shared->bcTypes.Count();
+ auto bcTypes = allocateArray<BCPtr<BCType>>(context, typeCount);
+ bcModule->typeCount = typeCount;
+ bcModule->types = bcTypes;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
+ for(UInt tt = 0; tt < typeCount; ++tt)
{
- // Generate bytecode for global-scope inst
- auto globalSymbol = generateBytecodeForGlobalSymbol(context, inst);
- globalSymbols.Add(globalSymbol);
+ bcTypes[tt] = context->shared->bcTypes[tt];
}
-#endif
+
return bcModule;
}
@@ -833,10 +1042,6 @@ void generateBytecodeStream(
memcpy(header->magic, "slang\0bc", sizeof(header->magic));
header->version = 0;
- // HACK: ensure that a NULL pointer in an operand field can
- // be encoded.
- context->shared->mapGlobalSymbolToGLobalID.Add(nullptr, -1);
-
header->module = generateBytecodeForModule(context, irModule);
}
diff --git a/source/slang/bytecode.h b/source/slang/bytecode.h
index e073763cf..1ea16406f 100644
--- a/source/slang/bytecode.h
+++ b/source/slang/bytecode.h
@@ -65,6 +65,7 @@ struct BCPtr
}
operator T*() const { return getPtr(); }
+ T* operator->() const { return getPtr(); }
T* getPtr() const
{
@@ -73,9 +74,52 @@ struct BCPtr
}
};
-struct BCType
+// Representation of a "type-level" value in
+// the bytecode fiel. This corresponds to
+// the AST-level notion of a `Val`
+struct BCVal
{
+ // The opcode used to define this value
uint32_t op;
+
+ // The ID of the type within its module
+ uint32_t id;
+};
+
+struct BCType : BCVal
+{
+ // TODO: avoid having to encode this?
+ uint32_t argCount;
+
+ // type-specific operands follow
+
+ //
+
+ BCPtr<BCVal>* getArgs() { return (BCPtr<BCVal>*) (this +1); }
+
+ BCVal* getArg(UInt index) { return getArgs()[index]; }
+};
+
+struct BCPtrType : BCType
+{
+ BCPtr<BCType> valueType;
+};
+
+struct BCFuncType : BCType
+{
+ BCPtr<BCType> resultType;
+ BCPtr<BCType> paramTypes[1];
+
+ BCType* getResultType() { return resultType; }
+
+ UInt getParamCount() { return argCount - 1; }
+ BCType* getParamType(UInt index) { return paramTypes[index]; }
+};
+
+struct BCConstant : BCVal
+{
+ uint32_t typeID;
+ BCPtr<uint8_t> ptr;
};
struct BCSymbol
@@ -84,15 +128,9 @@ struct BCSymbol
// this symbol; used to categorize things
uint32_t op;
- // The type of the symbol is represent
- // as an index into the global-scope symbol
- // list of the module.
- //
- // Note: This currently precludes having
- // a register with a type that is not
- // statically determined, but that is
- // probably okay.
- uint32_t typeGlobalID;
+ // The index (in the module's type table)
+ // of the type of the symbol:
+ uint32_t typeID;
// The name of this symbol (which might
// be a mangled name at some point,
@@ -111,17 +149,18 @@ struct BCReg : BCSymbol
uint32_t previousVarIndexPlusOne;
};
+enum BCConstFlavor
+{
+ kBCConstFlavor_GlobalSymbol,
+ kBCConstFlavor_Constant,
+};
+
struct BCConst
{
- // The ID of the symbol in the global
- // scope that we are trying to refer
- // to.
- //
- // TODO: eventually, if we have general
- // nesting, then this might be the
- // entry in the outer scope that
- // is being referenced.
- uint32_t globalID;
+ // The flavor of bytecode constant we
+ // are dealing with.
+ uint32_t flavor;
+ uint32_t id;
};
struct BCBlock
@@ -153,16 +192,27 @@ struct BCFunc : BCSymbol
// but this would make the encoding less dense.
uint32_t constCount;
BCPtr<BCConst> consts;
-
- // Data for "nested" symbols (e.g., a function
- // nested inside this function).
- uint32_t nestedSymbolCount;
- BCPtr<BCPtr<BCSymbol>> nestedSymbols;
};
-// A module is encoded more or less like a function.
-struct BCModule : BCFunc
+struct BCModule
{
+ // The symbols (functions, global variables, etc.)
+ // that have been declared in the module.
+ uint32_t symbolCount;
+ BCPtr<BCPtr<BCSymbol>> symbols;
+
+ // The types that are used by this module, stored
+ // in a single array so that they can be conveniently
+ // mapped to another representation in one go.
+ //
+ // Instructions in a bytecode instruction sequence
+ // might reference these types by index.
+ uint32_t typeCount;
+ BCPtr<BCPtr<BCType>> types;
+
+ // True compile-time constants go here:
+ uint32_t constantCount;
+ BCPtr<BCConstant> constants;
};
struct BCHeader
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 5c1f7380c..73d464d95 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -4546,26 +4546,21 @@ namespace Slang
// if this is still an invoke expression, test arguments passed to inout/out parameter are LValues
if(auto funcType = invoke->FunctionExpr->type->As<FuncType>())
{
- List<RefPtr<ParamDecl>> paramsStorage;
- List<RefPtr<ParamDecl>> * params = nullptr;
- if (auto func = funcType->declRef.getDecl())
+ UInt paramCount = funcType->getParamCount();
+ for (UInt pp = 0; pp < paramCount; ++pp)
{
- paramsStorage = func->GetParameters().ToArray();
- params = &paramsStorage;
- }
- if (params)
- {
- for (UInt i = 0; i < (*params).Count(); i++)
+ auto paramType = funcType->getParamType(pp);
+ if (auto outParamType = paramType->As<OutTypeBase>())
{
- if ((*params)[i]->HasModifier<OutModifier>())
+ if (pp < expr->Arguments.Count()
+ && !expr->Arguments[pp]->type.IsLeftValue)
{
- if (i < expr->Arguments.Count() && expr->Arguments[i]->type->AsBasicType() &&
- !expr->Arguments[i]->type.IsLeftValue)
+ if (!isRewriteMode())
{
- if (!isRewriteMode())
- {
- getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->getName());
- }
+ getSink()->diagnose(
+ expr->Arguments[pp],
+ Diagnostics::argumentExpectedLValue,
+ pp);
}
}
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 7ec812d63..1919699a1 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -15,6 +15,7 @@ namespace Slang
struct IncludeHandler;
class CompileRequest;
class ProgramLayout;
+ class PtrType;
enum class CompilerMode
{
@@ -369,6 +370,7 @@ namespace Slang
RefPtr<Type> errorType;
RefPtr<Type> initializerListType;
RefPtr<Type> overloadedType;
+ RefPtr<Type> irBasicBlockType;
Dictionary<int, RefPtr<Type>> builtinTypes;
Dictionary<String, Decl*> magicDecls;
@@ -388,6 +390,12 @@ namespace Slang
Type* getOverloadedType();
Type* getErrorType();
+ // Should not be used in front-end code
+ Type* getIRBasicBlockType();
+
+ // Construct pointer types on-demand
+ RefPtr<PtrType> getPtrType(RefPtr<Type> valueType);
+
SyntaxClass<RefObject> findSyntaxClass(Name* name);
Dictionary<Name*, SyntaxClass<RefObject> > mapNameToSyntaxClass;
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 2e200f63a..3755c51c7 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -85,6 +85,17 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
sb << "};\n";
}
+// Declare built-in pointer type
+// (eventually we can have the traditional syntax sugar for this)
+
+}}}}
+
+__generic<T>
+__magic_type(PtrType)
+struct Ptr
+{};
+
+${{{{
// Declare vector and matrix types
diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h
index cf2052d3c..864196944 100644
--- a/source/slang/core.meta.slang.h
+++ b/source/slang/core.meta.slang.h
@@ -86,6 +86,18 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
sb << "};\n";
}
+// Declare built-in pointer type
+// (eventually we can have the traditional syntax sugar for this)
+
+sb << "\n";
+sb << "\n";
+sb << "__generic<T>\n";
+sb << "__magic_type(PtrType)\n";
+sb << "struct Ptr\n";
+sb << "{};\n";
+sb << "\n";
+sb << "";
+
// Declare vector and matrix types
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index d4c1be706..64ac85bd4 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -93,6 +93,10 @@ struct SharedEmitContext
bool needHackSamplerForTexelFetch = false;
ExtensionUsageTracker extensionUsageTracker;
+
+ Dictionary<IRValue*, UInt> mapIRValueToID;
+
+ HashSet<Decl*> irDeclsVisited;
};
struct EmitContext
@@ -1025,6 +1029,9 @@ struct EmitVisitor
UNEXPECTED(GenericDeclRefType);
UNEXPECTED(InitializerListType);
+ UNEXPECTED(IRBasicBlockType);
+ UNEXPECTED(PtrType);
+
#undef UNEXPECTED
void visitNamedExpressionType(NamedExpressionType* type, TypeEmitArg const& arg)
@@ -1231,6 +1238,17 @@ struct EmitVisitor
EmitType(type, SourceLoc(), name, SourceLoc());
}
+ void EmitType(RefPtr<Type> type, String const& name)
+ {
+ // HACK: the rest of the code wants a `Name`,
+ // so we'll create one for a bit...
+ Name tempName;
+ tempName.text = name;
+
+ EmitType(type, SourceLoc(), &tempName, SourceLoc());
+ }
+
+
void EmitType(RefPtr<Type> type)
{
emitTypeImpl(type, nullptr);
@@ -3942,8 +3960,34 @@ emitDeclImpl(decl, nullptr);
// IR-level emit logc
- String getName(IRInst* inst)
+ UInt getID(IRValue* value)
+ {
+ auto& mapIRValueToID = context->shared->mapIRValueToID;
+
+ UInt id = 0;
+ if (mapIRValueToID.TryGetValue(value, id))
+ return id;
+
+ id = mapIRValueToID.Count() + 1;
+ mapIRValueToID.Add(value, id);
+ return id;
+ }
+
+ String getName(IRValue* inst)
{
+ switch(inst->op)
+ {
+ case kIROp_decl_ref:
+ {
+ auto irDeclRef = (IRDeclRef*) inst;
+ return getText(irDeclRef->declRef.GetName());
+ }
+ break;
+
+ default:
+ break;
+ }
+
if(auto decoration = inst->findDecoration<IRHighLevelDeclDecoration>())
{
auto decl = decoration->decl;
@@ -3957,7 +4001,7 @@ emitDeclImpl(decl, nullptr);
StringBuilder sb;
sb << "_S";
- sb << inst->id;
+ sb << getID(inst);
return sb.ProduceString();
}
@@ -4035,7 +4079,7 @@ emitDeclImpl(decl, nullptr);
}
-
+#if 0
void emitIRSimpleType(
EmitContext* context,
IRType* type)
@@ -4153,12 +4197,14 @@ emitDeclImpl(decl, nullptr);
}
}
+#endif
CodeGenTarget getTarget(EmitContext* context)
{
return context->shared->target;
}
+#if 0
void emitGLSLTypePrefix(
EmitContext* context,
IRType* type)
@@ -4194,44 +4240,9 @@ emitDeclImpl(decl, nullptr);
break;
}
}
+#endif
- void emitIRVectorType(
- EmitContext* context,
- IRVectorType* type)
- {
- switch(getTarget(context))
- {
- case CodeGenTarget::GLSL:
- // HLSL style: `<elementTypePrefix>vec<elementCount>`
- // e.g., `ivec4`
- //
- emitGLSLTypePrefix(context, type->getElementType());
- emit("vec");
- emitIRSimpleValue(context, type->getElementCount());
- break;
-
- default:
- // HLSL style: `<elementTypeName><elementCount>`
- // e.g., `int4`
- //
- emitIRSimpleType(context, type->getElementType());
- emitIRSimpleValue(context, type->getElementCount());
- break;
- }
- }
-
- void emitIRMatrixType(
- EmitContext* context,
- IRMatrixType* type)
- {
- // TODO: this is a GLSL-vs-HLSL decision point
-
- emitIRSimpleType(context, type->getElementType());
- emitIRSimpleValue(context, type->getRowCount());
- emit("x");
- emitIRSimpleValue(context, type->getColumnCount());
- }
-
+#if 0
void emitIRType(
EmitContext* context,
IRType* type,
@@ -4287,10 +4298,11 @@ emitDeclImpl(decl, nullptr);
{
emitIRType(context, type, (IRDeclaratorInfo*) nullptr);
}
+#endif
bool shouldFoldIRInstIntoUseSites(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
// Certain opcodes should always be folded in
switch( inst->op )
@@ -4306,27 +4318,24 @@ emitDeclImpl(decl, nullptr);
return true;
}
- // Certain *types* will usually want to be folded in
+ // Certain *types* will usually want to be folded in,
+ // because they aren't allowed as types for temporary
+ // variables.
auto type = inst->getType();
- switch (type->op)
+ if(type->As<UniformParameterBlockType>())
{
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
// TODO: we need to be careful here, because
// HLSL shader model 6 allows these as explicit
// types.
return true;
-
- case kIROp_TextureType:
+ }
+ else if(type->As<TextureTypeBase>())
+ {
// GLSL doesn't allow texture/resource types to
// be used as first-class values, so we need
// to fold them into their use sites in all cases
if(getTarget(context) == CodeGenTarget::GLSL)
return true;
- break;
-
- default:
- break;
}
// By default we will *not* fold things into their use sites.
@@ -4335,20 +4344,16 @@ emitDeclImpl(decl, nullptr);
bool isDerefBaseImplicit(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
auto type = inst->getType();
- switch (type->op)
+
+ if(type->As<UniformParameterBlockType>())
{
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
// TODO: we need to be careful here, because
// HLSL shader model 6 allows these as explicit
// types.
return true;
-
- default:
- break;
}
return false;
@@ -4358,7 +4363,7 @@ emitDeclImpl(decl, nullptr);
void emitIROperand(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
if( shouldFoldIRInstIntoUseSites(context, inst) )
{
@@ -4380,8 +4385,8 @@ emitDeclImpl(decl, nullptr);
EmitContext* context,
IRInst* inst)
{
- UInt argCount = inst->argCount - 1;
- IRUse* args = inst->getArgs() + 1;
+ UInt argCount = inst->argCount;
+ IRUse* args = inst->getArgs();
emit("(");
for(UInt aa = 0; aa < argCount; ++aa)
@@ -4392,12 +4397,35 @@ emitDeclImpl(decl, nullptr);
emit(")");
}
+ void emitIRType(
+ EmitContext* context,
+ IRType* type,
+ String const& name)
+ {
+ EmitType(type, name);
+ }
+
+ void emitIRType(
+ EmitContext* context,
+ IRType* type,
+ Name* name)
+ {
+ EmitType(type, name);
+ }
+
+ void emitIRType(
+ EmitContext* context,
+ IRType* type)
+ {
+ EmitType(type);
+ }
+
void emitIRInstResultDecl(
EmitContext* context,
IRInst* inst)
{
auto type = inst->getType();
- if(!type || type->op == kIROp_VoidType)
+ if(!type)
return;
emitIRType(context, type, getName(inst));
@@ -4406,9 +4434,10 @@ emitDeclImpl(decl, nullptr);
void emitIRInstExpr(
EmitContext* context,
- IRInst* inst)
+ IRValue* value)
{
- switch(inst->op)
+ IRInst* inst = (IRInst*) value;
+ switch(value->op)
{
case kIROp_IntLit:
case kIROp_FloatLit:
@@ -4418,13 +4447,13 @@ emitDeclImpl(decl, nullptr);
case kIROp_Construct:
// Simple constructor call
- if( inst->getArgCount() == 2 && getTarget(context) == CodeGenTarget::HLSL)
+ if( inst->getArgCount() == 1 && getTarget(context) == CodeGenTarget::HLSL)
{
// Need to emit as cast for HLSL
emit("(");
emitIRType(context, inst->getType());
emit(") ");
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
}
else
{
@@ -4446,7 +4475,7 @@ emitDeclImpl(decl, nullptr);
emitIRType(context, inst->getType());
}
emit("(");
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(")");
break;
@@ -4480,9 +4509,9 @@ emitDeclImpl(decl, nullptr);
#define CASE(OPCODE, OP) \
case OPCODE: \
- emitIROperand(context, inst->getArg(1)); \
+ emitIROperand(context, inst->getArg(0)); \
emit("" #OP " "); \
- emitIROperand(context, inst->getArg(2)); \
+ emitIROperand(context, inst->getArg(1)); \
break
CASE(kIROp_Add, +);
@@ -4512,7 +4541,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_Not:
{
- if (inst->getType()->op == kIROp_BoolType)
+ if (inst->getType()->Equals(getSession()->getBoolType()))
{
emit("!");
}
@@ -4520,73 +4549,72 @@ emitDeclImpl(decl, nullptr);
{
emit("~");
}
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
}
break;
case kIROp_Sample:
- // argument 0 is the instruction's type
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(".Sample(");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit(", ");
- emitIROperand(context, inst->getArg(3));
+ emitIROperand(context, inst->getArg(2));
emit(")");
break;
case kIROp_SampleGrad:
// argument 0 is the instruction's type
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(".SampleGrad(");
+ emitIROperand(context, inst->getArg(1));
+ emit(", ");
emitIROperand(context, inst->getArg(2));
emit(", ");
emitIROperand(context, inst->getArg(3));
emit(", ");
emitIROperand(context, inst->getArg(4));
- emit(", ");
- emitIROperand(context, inst->getArg(5));
emit(")");
break;
case kIROp_Load:
// TODO: this logic will really only work for a simple variable reference...
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
break;
case kIROp_Store:
// TODO: this logic will really only work for a simple variable reference...
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(" = ");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
break;
case kIROp_Call:
{
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("(");
- UInt argCount = inst->getArgCount() - 2;
+ UInt argCount = inst->getArgCount() - 1;
for( UInt aa = 0; aa < argCount; ++aa )
{
if(aa != 0) emit(", ");
- emitIROperand(context, inst->getArg(aa + 2));
+ emitIROperand(context, inst->getArg(aa + 1));
}
emit(")");
}
break;
case kIROp_BufferLoad:
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("[");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit("]");
break;
case kIROp_BufferStore:
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("[");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit("] = ");
- emitIROperand(context, inst->getArg(3));
+ emitIROperand(context, inst->getArg(2));
break;
case kIROp_GroupMemoryBarrierWithGroupSync:
@@ -4595,9 +4623,9 @@ emitDeclImpl(decl, nullptr);
case kIROp_getElement:
case kIROp_getElementPtr:
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("[");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit("]");
break;
@@ -4605,9 +4633,9 @@ emitDeclImpl(decl, nullptr);
case kIROp_Mul_Matrix_Vector:
case kIROp_Mul_Matrix_Matrix:
emit("mul(");
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(", ");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit(")");
break;
@@ -4619,7 +4647,7 @@ emitDeclImpl(decl, nullptr);
UInt elementCount = ii->getElementCount();
for (UInt ee = 0; ee < elementCount; ++ee)
{
- IRInst* irElementIndex = ii->getElementIndex(ee);
+ IRValue* irElementIndex = ii->getElementIndex(ee);
assert(irElementIndex->op == kIROp_IntLit);
IRConstant* irConst = (IRConstant*)irElementIndex;
@@ -4658,7 +4686,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_Var:
{
auto ptrType = inst->getType();
- auto valType = ((IRPtrType*)ptrType)->getValueType();
+ auto valType = ((PtrType*)ptrType)->getValueType();
auto name = getName(inst);
emitIRType(context, valType, name);
@@ -4689,14 +4717,14 @@ emitDeclImpl(decl, nullptr);
{
auto ii = (IRSwizzleSet*)inst;
emitIRInstResultDecl(context, inst);
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(";\n");
emitIROperand(context, inst);
emit(".");
UInt elementCount = ii->getElementCount();
for (UInt ee = 0; ee < elementCount; ++ee)
{
- IRInst* irElementIndex = ii->getElementIndex(ee);
+ IRValue* irElementIndex = ii->getElementIndex(ee);
assert(irElementIndex->op == kIROp_IntLit);
IRConstant* irConst = (IRConstant*)irElementIndex;
@@ -4707,34 +4735,16 @@ emitDeclImpl(decl, nullptr);
emit(kComponents[elementIndex]);
}
emit(" = ");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit(";\n");
}
break;
-
-#if 0
- case kIROp_unconditionalBranch:
- emit("// unconditionalBranch ");
- emitIROperand(context, inst->getArg(1));
- emit("\n");
- break;
-
- case kIROp_conditionalBranch:
- emit("// conditionalBranch ");
- emitIROperand(context, inst->getArg(1));
- emit(" ");
- emitIROperand(context, inst->getArg(2));
- emit(" ");
- emitIROperand(context, inst->getArg(3));
- emit("\n");
- break;
-#endif
}
}
void emitIRSemantics(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
auto decoration = inst->findDecoration<IRHighLevelDeclDecoration>();
if( decoration )
@@ -4745,7 +4755,7 @@ emitDeclImpl(decl, nullptr);
VarLayout* getVarLayout(
EmitContext* context,
- IRInst* var)
+ IRValue* var)
{
auto decoration = var->findDecoration<IRLayoutDecoration>();
if (!decoration)
@@ -4756,7 +4766,7 @@ emitDeclImpl(decl, nullptr);
void emitIRLayoutSemantics(
EmitContext* context,
- IRInst* inst,
+ IRValue* inst,
char const* uniformSemanticSpelling = "register")
{
auto layout = getVarLayout(context, inst);
@@ -4784,9 +4794,9 @@ emitDeclImpl(decl, nullptr);
while(block != end)
{
// Start by emitting the non-terminator instructions in the block.
- auto terminator = block->lastChild;
+ auto terminator = block->getLastInst();
assert(isTerminatorInst(terminator));
- for (auto inst = block->firstChild; inst != terminator; inst = inst->nextInst)
+ for (auto inst = block->getFirstInst(); inst != terminator; inst = inst->nextInst)
{
emitIRInst(context, inst);
}
@@ -5095,10 +5105,8 @@ emitDeclImpl(decl, nullptr);
// encoded as a parameter of pointer type, so
// we need to decode that here.
//
- if( paramType->op == kIROp_PtrType )
+ if( auto ptrType = paramType->As<PtrType>() )
{
- auto ptrType = (IRPtrType*) paramType;
-
// TODO: we need a way to distinguish `out`
// from `inout`. The easiest way to do
// that might be to have each be a distinct
@@ -5243,7 +5251,7 @@ emitDeclImpl(decl, nullptr);
}
}
-
+#if 0
void emitIRStruct(
EmitContext* context,
IRStructDecl* structType)
@@ -5263,6 +5271,7 @@ emitDeclImpl(decl, nullptr);
}
emit("};\n");
}
+#endif
void emitIRVarModifiers(
EmitContext* context,
@@ -5317,9 +5326,9 @@ emitDeclImpl(decl, nullptr);
}
void emitIRParameterBlock(
- EmitContext* context,
- IRVar* varDecl,
- IRUniformBufferType* type)
+ EmitContext* context,
+ IRGlobalVar* varDecl,
+ UniformParameterBlockType* type)
{
emit("cbuffer ");
emit(getName(varDecl));
@@ -5341,17 +5350,15 @@ emitDeclImpl(decl, nullptr);
typeLayout = parameterBlockTypeLayout->elementTypeLayout;
}
- switch( elementType->op )
+ if(auto declRefType = elementType->As<DeclRefType>())
{
- case kIROp_StructType:
+ if(auto structDeclRef = declRefType->declRef.As<StructDecl>())
{
- auto structType = (IRStructDecl*) elementType;
-
auto structTypeLayout = typeLayout.As<StructTypeLayout>();
assert(structTypeLayout);
UInt fieldIndex = 0;
- for(auto ff = structType->getFirstField(); ff; ff = ff->getNextField())
+ for(auto ff : GetFields(structDeclRef))
{
// TODO: need a plan to deal with the case where the IR-level
// `struct` type might not match the high-level type, so that
@@ -5367,19 +5374,18 @@ emitDeclImpl(decl, nullptr);
emitIRVarModifiers(context, fieldLayout);
- auto fieldType = ff->getFieldType();
- emitIRType(context, fieldType, getName(ff));
+ auto fieldType = GetType(ff);
+ emitIRType(context, fieldType, ff.GetName());
emitHLSLParameterBlockFieldLayoutSemantics(layout, fieldLayout);
emit(";\n");
}
}
- break;
-
- default:
+ }
+ else
+ {
emit("/* unexpected */");
- break;
}
emit("}\n");
@@ -5391,8 +5397,9 @@ emitDeclImpl(decl, nullptr);
{
auto allocatedType = varDecl->getType();
auto varType = allocatedType->getValueType();
- auto addressSpace = allocatedType->getAddressSpace();
+// auto addressSpace = allocatedType->getAddressSpace();
+#if 0
switch( varType->op )
{
case kIROp_ConstantBufferType:
@@ -5403,6 +5410,7 @@ emitDeclImpl(decl, nullptr);
default:
break;
}
+#endif
// Need to emit appropriate modifiers here.
@@ -5410,6 +5418,7 @@ emitDeclImpl(decl, nullptr);
emitIRVarModifiers(context, layout);
+#if 0
switch (addressSpace)
{
default:
@@ -5419,6 +5428,51 @@ emitDeclImpl(decl, nullptr);
emit("groupshared ");
break;
}
+#endif
+
+ emitIRType(context, varType, getName(varDecl));
+
+ emitIRSemantics(context, varDecl);
+
+ emitIRLayoutSemantics(context, varDecl);
+
+ emit(";\n");
+ }
+
+ void emitIRGlobalVar(
+ EmitContext* context,
+ IRGlobalVar* varDecl)
+ {
+ auto allocatedType = varDecl->getType();
+ auto varType = allocatedType->getValueType();
+// auto addressSpace = allocatedType->getAddressSpace();
+
+ if (auto paramBlockType = varType->As<UniformParameterBlockType>())
+ {
+ emitIRParameterBlock(
+ context,
+ varDecl,
+ paramBlockType);
+ return;
+ }
+
+ // Need to emit appropriate modifiers here.
+
+ auto layout = getVarLayout(context, varDecl);
+
+ emitIRVarModifiers(context, layout);
+
+#if 0
+ switch (addressSpace)
+ {
+ default:
+ break;
+
+ case kIRAddressSpace_GroupShared:
+ emit("groupshared ");
+ break;
+ }
+#endif
emitIRType(context, varType, getName(varDecl));
@@ -5431,7 +5485,7 @@ emitDeclImpl(decl, nullptr);
void emitIRGlobalInst(
EmitContext* context,
- IRInst* inst)
+ IRGlobalValue* inst)
{
// TODO: need to be able to `switch` on the IR opcode here,
// so there is some work to be done.
@@ -5441,9 +5495,15 @@ emitDeclImpl(decl, nullptr);
emitIRFunc(context, (IRFunc*) inst);
break;
+ case kIROp_global_var:
+ emitIRGlobalVar(context, (IRGlobalVar*) inst);
+ break;
+
+#if 0
case kIROp_StructType:
emitIRStruct(context, (IRStructDecl*) inst);
break;
+#endif
case kIROp_Var:
emitIRVar(context, (IRVar*) inst);
@@ -5454,13 +5514,154 @@ emitDeclImpl(decl, nullptr);
}
}
+ void ensureStructDecl(
+ EmitContext* context,
+ DeclRef<StructDecl> declRef)
+ {
+ // TODO: Eventually need to deal with the case where
+ // we have user-defined generic types.
+ //
+ auto decl = declRef.getDecl();
+
+ if(context->shared->irDeclsVisited.Contains(decl))
+ return;
+
+ context->shared->irDeclsVisited.Add(decl);
+
+ // First emit any types used by fields of this type
+ for( auto ff : GetFields(declRef) )
+ {
+ if(ff.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ auto fieldType = GetType(ff);
+ emitIRUsedType(context, fieldType);
+ }
+
+ Emit("struct ");
+ emit(declRef.GetName());
+ Emit("\n{\n");
+ for( auto ff : GetFields(declRef) )
+ {
+ if(ff.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ auto fieldType = GetType(ff);
+ emitIRType(context, fieldType, ff.GetName());
+
+ EmitSemantics(ff.getDecl());
+
+ emit(";\n");
+ }
+ Emit("};\n");
+ }
+
+ // A type is going to be used by the IR, so
+ // make sure that we have emitted whatever
+ // it needs.
+ void emitIRUsedType(
+ EmitContext* context,
+ Type* type)
+ {
+ if(type->As<BasicExpressionType>())
+ {}
+ else if(type->As<VectorExpressionType>())
+ {}
+ else if(type->As<MatrixExpressionType>())
+ {}
+ else if(auto arrayType = type->As<ArrayExpressionType>())
+ {
+ emitIRUsedType(context, arrayType->baseType);
+ }
+ else if( auto textureType = type->As<TextureTypeBase>() )
+ {
+ emitIRUsedType(context, textureType->elementType);
+ }
+ else if( auto genericType = type->As<BuiltinGenericType>() )
+ {
+ emitIRUsedType(context, genericType->elementType);
+ }
+ else if( auto ptrType = type->As<PtrType>() )
+ {
+ emitIRUsedType(context, ptrType->getValueType());
+ }
+ else if(type->As<SamplerStateType>() )
+ {
+ }
+ else if( auto declRefType = type->As<DeclRefType>() )
+ {
+ auto declRef = declRefType->declRef;
+ auto decl = declRef.getDecl();
+
+ if(decl->HasModifier<BuiltinTypeModifier>()
+ || decl->HasModifier<MagicTypeModifier>())
+ {
+ return;
+ }
+
+ if( auto structDeclRef = declRef.As<StructDecl>() )
+ {
+ //
+ ensureStructDecl(context, structDeclRef);
+ }
+ }
+ else
+ {}
+ }
+
+ void emitIRUsedTypesForValue(
+ EmitContext* context,
+ IRValue* value)
+ {
+ if(!value) return;
+ switch( value->op )
+ {
+ case kIROp_Func:
+ {
+ auto irFunc = (IRFunc*) value;
+ emitIRUsedType(context, irFunc->getResultType());
+ for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() )
+ {
+ for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() )
+ {
+ emitIRUsedTypesForValue(context, pp);
+ }
+
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
+ {
+ emitIRUsedTypesForValue(context, ii);
+ }
+ }
+ }
+ break;
+
+ default:
+ {
+ emitIRUsedType(context, value->type);
+ }
+ break;
+ }
+ }
+
+ void emitIRUsedTypesForModule(
+ EmitContext* context,
+ IRModule* module)
+ {
+ for (auto gv : module->globalValues)
+ {
+ emitIRUsedTypesForValue(context, gv);
+ }
+ }
+
void emitIRModule(
EmitContext* context,
IRModule* module)
{
- for(auto ii = module->firstChild; ii; ii = ii->nextInst )
+ emitIRUsedTypesForModule(context, module);
+
+ for (auto gv : module->globalValues)
{
- emitIRGlobalInst(context, ii);
+ emitIRGlobalInst(context, gv);
}
}
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index 60dc353ac..c11d66571 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -22,20 +22,20 @@ INST(arrayType, Array, 2, 0)
INST(BoolType, Bool, 0, 0)
-INST(Float16Type, Float16, 0, 0)
-INST(Float32Type, Float32, 0, 0)
-INST(Float64Type, Float64, 0, 0)
+INST(Float16Type, Float16, 0, 0)
+INST(Float32Type, Float32, 0, 0)
+INST(Float64Type, Float64, 0, 0)
// Signed integer types.
// Note that `IntPtr` represents a pointer-sized integer type,
// and will end up being equivalent to either `Int32` or `Int64`
// when it comes time to actually generate code.
//
-INST(Int8Type, Int8, 0, 0)
-INST(Int16Type, Int16, 0, 0)
-INST(Int32Type, Int32, 0, 0)
-INST(IntPtrType, IntPtr, 0, 0)
-INST(Int64Type, Int64, 0, 0)
+INST(Int8Type, Int8, 0, 0)
+INST(Int16Type, Int16, 0, 0)
+INST(Int32Type, Int32, 0, 0)
+INST(IntPtrType, IntPtr, 0, 0)
+INST(Int64Type, Int64, 0, 0)
// Unlike a lot of other IRs, we retain a distinction between
// signed and unsigned integer types, simply because many of
@@ -52,11 +52,11 @@ INST(Int64Type, Int64, 0, 0)
// or else we want to keep using the orignal types, but need
// to cast around any ordinary math operations on signed types.
//
-INST(UInt8Type, Int8, 0, 0)
-INST(UInt16Type, Int16, 0, 0)
-INST(UInt32Type, Int32, 0, 0)
-INST(UIntPtrType, IntPtr, 0, 0)
-INST(UInt64Type, Int64, 0, 0)
+INST(UInt8Type, Int8, 0, 0)
+INST(UInt16Type, Int16, 0, 0)
+INST(UInt32Type, Int32, 0, 0)
+INST(UIntPtrType, IntPtr, 0, 0)
+INST(UInt64Type, Int64, 0, 0)
// A user-defined structure declaration at the IR level.
// Unlike in the AST where there is a distinction between
@@ -94,6 +94,7 @@ INST(GenericParameterType, GenericParameterType, 1, 0)
INST(boolConst, boolConst, 0, 0)
INST(IntLit, integer_constant, 0, 0)
INST(FloatLit, float_constant, 0, 0)
+INST(decl_ref, decl_ref, 0, 0)
INST(Construct, construct, 0, 0)
INST(Call, call, 1, 0)
@@ -102,6 +103,8 @@ INST(Module, module, 0, PARENT)
INST(Func, func, 0, PARENT)
INST(Block, block, 0, PARENT)
+INST(global_var, global_var, 0, 0)
+
INST(Param, param, 0, 0)
INST(StructField, field, 0, 0)
INST(Var, var, 0, 0)
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 1e9638ade..74151adaa 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -10,6 +10,7 @@
#include "compiler.h"
#include "ir.h"
+#include "syntax.h"
namespace Slang {
@@ -62,65 +63,22 @@ struct IRLoopControlDecoration : IRDecoration
IRLoopControl mode;
};
-struct IRMangledNameDecoration : IRDecoration
-{
- enum { kDecorationOp = kIRDecorationOp_MangledName };
-
- String mangledName;
-};
-
-// Types
-
-struct IRVectorType : IRType
-{
- IRUse elementType;
- IRUse elementCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getElementCount() { return elementCount.usedValue; }
-};
-
-struct IRMatrixType : IRType
-{
- IRUse elementType;
- IRUse rowCount;
- IRUse columnCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getRowCount() { return rowCount.usedValue; }
- IRInst* getColumnCount() { return columnCount.usedValue; }
-};
-
-struct IRArrayType : IRType
-{
- IRUse elementType;
- IRUse elementCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getElementCount() { return elementCount.usedValue; }
-};
+//
-struct IRGenericParameterType : IRType
+// An IR node to represent a reference to an AST-level
+// declaration.
+struct IRDeclRef : IRValue
{
- IRUse index;
+ DeclRefBase declRef;
};
+//
-struct IRFuncType : IRType
-{
- IRUse resultType;
- // parameter tyeps are varargs...
-
- IRType* getResultType() { return (IRType*) resultType.usedValue; }
- UInt getParamCount()
- {
- return getArgCount() - 2;
- }
- IRType* getParamType(UInt index)
- {
- return (IRType*) getArg(2 + index);
- }
-};
+// TODO(tfoley): the IR-level representation of
+// pointers had an "address space" field that the
+// current AST level lacks. This capability was
+// used to represent `groupshared` allocation,
+// so we probably need to find an alternative.
// Address spaces for IR pointers
enum IRAddressSpace : UInt
@@ -132,6 +90,7 @@ enum IRAddressSpace : UInt
kIRAddressSpace_GroupShared,
};
+#if 0
struct IRPtrType : IRType
{
IRUse valueType;
@@ -145,31 +104,7 @@ struct IRPtrType : IRType
((IRConstant*)addressSpace.usedValue)->u.intVal);
}
};
-
-struct IRTextureType : IRType
-{
- IRUse flavor;
- IRUse elementType;
-
- IRIntegerValue getFlavor() { return ((IRConstant*) flavor.usedValue)->u.intVal; }
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-struct IRBufferType : IRType
-{
- IRUse elementType;
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-
-struct IRUniformBufferType : IRType
-{
- IRUse elementType;
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-struct IRConstantBufferType : IRUniformBufferType {};
-struct IRTextureBufferType : IRUniformBufferType {};
+#endif
struct IRCall : IRInst
{
@@ -187,14 +122,13 @@ struct IRStore : IRInst
IRUse val;
};
-struct IRStructField;
struct IRFieldExtract : IRInst
{
IRUse base;
IRUse field;
- IRInst* getBase() { return base.usedValue; }
- IRStructField* getField() { return (IRStructField*) field.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
+ IRValue* getField() { return field.usedValue; }
};
struct IRFieldAddress : IRInst
@@ -202,8 +136,8 @@ struct IRFieldAddress : IRInst
IRUse base;
IRUse field;
- IRInst* getBase() { return base.usedValue; }
- IRStructField* getField() { return (IRStructField*) field.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
+ IRValue* getField() { return field.usedValue; }
};
// Terminators
@@ -215,7 +149,7 @@ struct IRReturnVal : IRReturn
{
IRUse val;
- IRInst* getVal() { return val.usedValue; }
+ IRValue* getVal() { return val.usedValue; }
};
struct IRReturnVoid : IRReturn
@@ -260,7 +194,7 @@ struct IRConditionalBranch : IRTerminatorInst
IRUse trueBlock;
IRUse falseBlock;
- IRInst* getCondition() { return condition.usedValue; }
+ IRValue* getCondition() { return condition.usedValue; }
IRBlock* getTrueBlock() { return (IRBlock*)trueBlock.usedValue; }
IRBlock* getFalseBlock() { return (IRBlock*)falseBlock.usedValue; }
};
@@ -296,12 +230,12 @@ struct IRSwizzle : IRReturn
{
IRUse base;
- IRInst* getBase() { return base.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
UInt getElementCount()
{
return getArgCount() - 2;
}
- IRInst* getElementIndex(UInt index)
+ IRValue* getElementIndex(UInt index)
{
return getArg(index + 2);
}
@@ -312,43 +246,43 @@ struct IRSwizzleSet : IRReturn
IRUse base;
IRUse source;
- IRInst* getBase() { return base.usedValue; }
- IRInst* getSource() { return source.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
+ IRValue* getSource() { return source.usedValue; }
UInt getElementCount()
{
return getArgCount() - 3;
}
- IRInst* getElementIndex(UInt index)
+ IRValue* getElementIndex(UInt index)
{
return getArg(index + 3);
}
};
-// "Parent" Instructions (Declarations)
-
-struct IRStructField : IRInst
+// An IR `var` instruction conceptually represents
+// a stack allocation of some memory.
+struct IRVar : IRInst
{
- IRType* getFieldType() { return (IRType*) type.usedValue; }
-
- IRStructField* getNextField() { return (IRStructField*) nextInst; }
+ PtrType* getType()
+ {
+ return (PtrType*)type.Ptr();
+ }
};
-struct IRStructDecl : IRParentInst
+struct IRGlobalVar : IRGlobalValue
{
- IRStructField* getFirstField() { return (IRStructField*) firstChild; }
- IRStructField* getLastField() { return (IRStructField*) lastChild; }
-};
+ // TODO: should contain information
+ // for use in initializing the variable
+ // (e.g., a reference to a function
+ // that is to be evaluated to provide
+ // the initial value, or a basic block
+ // that defines a DAG of constant
+ // values to use as initial values...)
-
-struct IRVar : IRInst
-{
- IRPtrType* getType()
- {
- return (IRPtrType*)type.usedValue;
- }
+ PtrType* getType() { return type.As<PtrType>(); }
};
+
// Description of an instruction to be used for global value numbering
struct IRInstKey
{
@@ -369,6 +303,13 @@ bool operator==(IRConstantKey const& left, IRConstantKey const& right);
struct SharedIRBuilder
{
+ // The parent compilation session
+ Session* session;
+ Session* getSession()
+ {
+ return session;
+ }
+
// The module that will own all of the IR
IRModule* module;
@@ -381,53 +322,36 @@ struct IRBuilder
// Shared state for all IR builders working on the same module
SharedIRBuilder* shared;
- IRModule* getModule() { return shared->module; }
-
- // The parent instruction to add children to.
- IRParentInst* parentInst;
-
- void addInst(IRParentInst* parent, IRInst* inst);
- void addInst(IRInst* inst);
-
- IRType* getBaseType(BaseType flavor);
- IRType* getBoolType();
- IRType* getVectorType(IRType* elementType, IRValue* elementCount);
- IRType* getMatrixType(
- IRType* elementType,
- IRValue* rowCount,
- IRValue* columnCount);
- IRType* getArrayType(IRType* elementType, IRValue* elementCount);
- IRType* getArrayType(IRType* elementType);
- IRType* getGenericParameterType(UInt index);
-
- IRType* getTypeType();
- IRType* getVoidType();
- IRType* getBlockType();
-
- IRType* getIntrinsicType(
- IROp op,
- UInt argCount,
- IRValue* const* args);
+ Session* getSession()
+ {
+ return shared->getSession();
+ }
- IRStructDecl* createStructType();
- IRStructField* createStructField(IRType* fieldType);
+ IRModule* getModule() { return shared->module; }
- IRType* getFuncType(
- UInt paramCount,
- IRType* const* paramTypes,
- IRType* resultType);
+ // The current function and block being inserted into
+ // (or `null` if we aren't inserting).
+ IRFunc* func = nullptr;
+ IRBlock* block = nullptr;
+ //
+ // TODO: we eventually also want an `IRInst*` for
+ // an instruction to insert before, so that we
+ // can also use the builder to insert inside
+ // an existing block.
- IRType* getPtrType(
- IRType* valueType,
- IRAddressSpace addressSpace);
+ IRFunc* getFunc() { return func; }
+ IRBlock* getBlock() { return block; }
- IRType* getPtrType(
- IRType* valueType);
+ void addInst(IRBlock* block, IRInst* inst);
+ void addInst(IRInst* inst);
IRValue* getBoolValue(bool value);
IRValue* getIntValue(IRType* type, IRIntegerValue value);
IRValue* getFloatValue(IRType* type, IRFloatingPointValue value);
+ IRValue* getDeclRefVal(
+ DeclRefBase const& declRef);
+
IRInst* emitCallInst(
IRType* type,
IRValue* func,
@@ -448,6 +372,8 @@ struct IRBuilder
IRModule* createModule();
IRFunc* createFunc();
+ IRGlobalVar* createGlobalVar(
+ IRType* valueType);
IRBlock* createBlock();
IRBlock* emitBlock();
@@ -472,12 +398,12 @@ struct IRBuilder
IRInst* emitFieldExtract(
IRType* type,
IRValue* base,
- IRStructField* field);
+ IRValue* field);
IRInst* emitFieldAddress(
IRType* type,
IRValue* basePtr,
- IRStructField* field);
+ IRValue* field);
IRInst* emitElementExtract(
IRType* type,
@@ -556,24 +482,24 @@ struct IRBuilder
IRBlock* breakBlock);
IRDecoration* addDecorationImpl(
- IRInst* inst,
+ IRValue* value,
UInt decorationSize,
IRDecorationOp op);
template<typename T>
- T* addDecoration(IRInst* inst, IRDecorationOp op)
+ T* addDecoration(IRValue* value, IRDecorationOp op)
{
- return (T*) addDecorationImpl(inst, sizeof(T), op);
+ return (T*) addDecorationImpl(value, sizeof(T), op);
}
template<typename T>
- T* addDecoration(IRInst* inst)
+ T* addDecoration(IRValue* value)
{
- return (T*) addDecorationImpl(inst, sizeof(T), IRDecorationOp(T::kDecorationOp));
+ return (T*) addDecorationImpl(value, sizeof(T), IRDecorationOp(T::kDecorationOp));
}
- IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRInst* inst, Decl* decl);
- IRLayoutDecoration* addLayoutDecoration(IRInst* inst, Layout* layout);
+ IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRValue* value, Decl* decl);
+ IRLayoutDecoration* addLayoutDecoration(IRValue* value, Layout* layout);
};
}
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 3e12ef1a2..ee28227e3 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -40,7 +40,7 @@ namespace Slang
//
- void IRUse::init(IRValue* u, IRValue* v)
+ void IRUse::init(IRInst* u, IRValue* v)
{
user = u;
usedValue = v;
@@ -58,10 +58,17 @@ namespace Slang
IRUse* IRInst::getArgs()
{
- return &type;
+ // We assume that *all* instructions are laid out
+ // in memory such that their arguments come right
+ // after the first `sizeof(IRInst)` bytes.
+ //
+ // TODO: we probably need to be careful and make
+ // this more robust.
+
+ return (IRUse*)(this + 1);
}
- IRDecoration* IRInst::findDecorationImpl(IRDecorationOp decorationOp)
+ IRDecoration* IRValue::findDecorationImpl(IRDecorationOp decorationOp)
{
for( auto dd = firstDecoration; dd; dd = dd->next )
{
@@ -71,13 +78,28 @@ namespace Slang
return nullptr;
}
+ // IRBlock
+
+ void IRBlock::addParam(IRParam* param)
+ {
+ if (auto lp = lastParam)
+ {
+ lp->nextParam = param;
+ param->prevParam = lp;
+ }
+ else
+ {
+ firstParam = param;
+ }
+ lastParam = param;
+ }
+
// IRFunc
IRType* IRFunc::getResultType() { return getType()->getResultType(); }
UInt IRFunc::getParamCount() { return getType()->getParamCount(); }
IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); }
-
IRParam* IRFunc::getFirstParam()
{
auto entryBlock = getFirstBlock();
@@ -86,41 +108,20 @@ namespace Slang
return entryBlock->getFirstParam();
}
- // IRBlock
-
- IRParam* IRBlock::getFirstParam()
- {
- auto firstInst = firstChild;
- if(!firstInst) return nullptr;
-
- if(firstInst->op != kIROp_Param)
- return nullptr;
-
- return (IRParam*) firstInst;
- }
-
-
- // IRParam
-
- IRParam* IRParam::getNextParam()
+ void IRFunc::addBlock(IRBlock* block)
{
- // TODO: this is written as a search because we don't
- // currently do the careful thing and emit parameters
- // before any other members of a block.
- //
- // This should change on the emit side, instead.
+ block->parentFunc = this;
- auto next = nextInst;
-
- for (;;)
+ if (auto lb = lastBlock)
{
- if (!next) return nullptr;
-
- if(next->op == kIROp_Param)
- return (IRParam*) next;
-
- next = next->nextInst;
+ lb->nextBlock = block;
+ block->prevBlock = lb;
+ }
+ else
+ {
+ firstBlock = block;
}
+ lastBlock = block;
}
//
@@ -156,27 +157,27 @@ namespace Slang
//
// Add an instruction to a specific parent
- void IRBuilder::addInst(IRParentInst* parent, IRInst* inst)
+ void IRBuilder::addInst(IRBlock* block, IRInst* inst)
{
- inst->parent = parent;
+ inst->parentBlock = block;
- if (!parent->firstChild)
+ if (!block->firstInst)
{
inst->prevInst = nullptr;
inst->nextInst = nullptr;
- parent->firstChild = inst;
- parent->lastChild = inst;
+ block->firstInst = inst;
+ block->lastInst = inst;
}
else
{
- auto prev = parent->lastChild;
+ auto prev = block->lastInst;
inst->prevInst = prev;
inst->nextInst = nullptr;
prev->nextInst = inst;
- parent->lastChild = inst;
+ block->lastInst = inst;
}
}
@@ -184,19 +185,42 @@ namespace Slang
void IRBuilder::addInst(
IRInst* inst)
{
- auto parent = parentInst;
+ auto parent = block;
if (!parent)
return;
addInst(parent, inst);
}
+ static IRValue* createValueImpl(
+ IRBuilder* builder,
+ UInt size,
+ IROp op,
+ IRType* type)
+ {
+ IRValue* value = (IRValue*) malloc(size);
+ memset(value, 0, size);
+ value->op = op;
+ value->type = type;
+ return value;
+ }
+
+ template<typename T>
+ static T* createValue(
+ IRBuilder* builder,
+ IROp op,
+ IRType* type)
+ {
+ return (T*) createValueImpl(builder, sizeof(T), op, type);
+ }
+
+
// Create an IR instruction/value and initialize it.
//
// In this case `argCount` and `args` represnt the
// arguments *after* the type (which is a mandatory
// argument for all instructions).
- static IRValue* createInstImpl(
+ static IRInst* createInstImpl(
IRBuilder* builder,
UInt size,
IROp op,
@@ -206,26 +230,17 @@ namespace Slang
UInt varArgCount = 0,
IRValue* const* varArgs = nullptr)
{
- IRValue* inst = (IRInst*) malloc(size);
+ IRInst* inst = (IRInst*) malloc(size);
memset(inst, 0, size);
auto module = builder->getModule();
- if (!module || (type && type->op == kIROp_VoidType))
- {
- // Can't or shouldn't assign an ID to this op
- }
- else
- {
- inst->id = ++module->idCounter;
- }
- inst->argCount = fixedArgCount + varArgCount + 1;
+ inst->argCount = fixedArgCount + varArgCount;
inst->op = op;
- auto operand = inst->getArgs();
+ inst->type = type;
- operand->init(inst, type);
- operand++;
+ auto operand = inst->getArgs();
for( UInt aa = 0; aa < fixedArgCount; ++aa )
{
@@ -394,7 +409,7 @@ namespace Slang
bool operator==(IRInstKey const& left, IRInstKey const& right)
{
if(left.inst->op != right.inst->op) return false;
- if(left.inst->parent != right.inst->parent) return false;
+ if(left.inst->parentBlock != right.inst->parentBlock) return false;
if(left.inst->argCount != right.inst->argCount) return false;
auto argCount = left.inst->argCount;
@@ -412,7 +427,7 @@ namespace Slang
int IRInstKey::GetHashCode()
{
auto code = Slang::GetHashCode(inst->op);
- code = combineHash(code, Slang::GetHashCode(inst->parent));
+ code = combineHash(code, Slang::GetHashCode(inst->parentBlock));
code = combineHash(code, Slang::GetHashCode(inst->argCount));
auto argCount = inst->argCount;
@@ -424,233 +439,21 @@ namespace Slang
return code;
}
- static IRParentInst* joinParentInstsForInsertion(
- IRParentInst* left,
- IRParentInst* right)
- {
- // Are they the same? Easy.
- if(left == right) return left;
-
- // Have we already failed to find a location? Then bail.
- if(!left) return nullptr;
- if(!right) return nullptr;
-
- // Is one inst a parent of the other? Pick the child.
- for( auto ll = left; ll; ll = ll->parent )
- {
- // Did we find the right node in the parent list of the left?
- if(ll == right) return left;
- }
- for( auto rr = right; rr; rr = rr->parent )
- {
- // Did we find the left node in the parent list of the right?
- if(rr == left) return right;
- }
-
- // Seems like they are unrelated, so we should play it safe
- return nullptr;
- }
-
-
- static IRInst* findOrEmitInstImpl(
- IRBuilder* builder,
- UInt size,
- IROp op,
- IRType* type,
- UInt fixedArgCount,
- IRValue* const* fixedArgs,
- UInt varArgCount = 0,
- IRValue* const* varArgs = nullptr)
- {
- // First, we need to pick a good insertion point
- // for the instruction, which we do by looking
- // at its operands.
- //
-
- IRParentInst* parent = builder->shared->module;
- if( type )
- {
- parent = joinParentInstsForInsertion(parent, type->parent);
- }
- for( UInt aa = 0; aa < fixedArgCount; ++aa )
- {
- auto arg = fixedArgs[aa];
- parent = joinParentInstsForInsertion(parent, arg->parent);
- }
- for( UInt aa = 0; aa < varArgCount; ++aa )
- {
- auto arg = varArgs[aa];
- parent = joinParentInstsForInsertion(parent, arg->parent);
- }
-
- // If we failed to find a good insertion point, then insert locally.
- if( !parent )
- {
- parent = builder->parentInst;
- }
-
- if( parent->op == kIROp_Func )
- {
- // We are trying to insert into a function, and we should really
- // be inserting into its entry block.
- assert(parent->firstChild);
- parent = (IRBlock*) ((IRFunc*) parent)->firstChild;
- }
-
- // We now know where we want to insert, but there might
- // already be an equivalent instruction in that block.
- //
- // We will check for such an instruction in a slightly hacky
- // way: we will construct a temporary instruction and
- // then use it to look up in a cache of instructions.
-
- IRInst* keyInst = createInstImpl(builder, size, op, type, fixedArgCount, fixedArgs, varArgCount, varArgs);
- keyInst->parent = parent;
-
- IRInstKey key;
- key.inst = keyInst;
-
- IRInst* inst = nullptr;
- if( builder->shared->globalValueNumberingMap.TryGetValue(key, inst) )
- {
- // We found a match, so just use that.
-
- free(keyInst);
- return inst;
- }
-
- // No match, so use our "key" instruction for real.
- inst = keyInst;
-
- builder->shared->globalValueNumberingMap.Add(key, inst);
-
- keyInst->parent = nullptr;
- builder->addInst(parent, inst);
-
- return inst;
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- UInt argCount,
- IRValue* const* args)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- argCount,
- args);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- UInt fixedArgCount,
- IRValue* const* fixedArgs,
- UInt varArgCount,
- IRValue* const* varArgs)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T) + varArgCount * sizeof(IRUse),
- op,
- type,
- fixedArgCount,
- fixedArgs,
- varArgCount,
- varArgs);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 0,
- nullptr);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- IRInst* arg)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 1,
- &arg);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- IRInst* arg1,
- IRInst* arg2)
- {
- IRInst* args[] = { arg1, arg2 };
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 2,
- &args[0]);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- IRInst* arg1,
- IRInst* arg2,
- IRInst* arg3)
- {
- IRInst* args[] = { arg1, arg2, arg3 };
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 3,
- &args[0]);
- }
-
//
bool operator==(IRConstantKey const& left, IRConstantKey const& right)
{
- if(left.inst->op != right.inst->op) return false;
- if(left.inst->type.usedValue != right.inst->type.usedValue) return false;
- if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false;
- if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false;
+ if(left.inst->op != right.inst->op) return false;
+ if(left.inst->type != right.inst->type) return false;
+ if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false;
+ if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false;
return true;
}
int IRConstantKey::GetHashCode()
{
auto code = Slang::GetHashCode(inst->op);
- code = combineHash(code, Slang::GetHashCode(inst->type.usedValue));
+ code = combineHash(code, Slang::GetHashCode(inst->type));
code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[0]));
code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[1]));
return code;
@@ -668,22 +471,20 @@ namespace Slang
// at its operands.
//
- IRParentInst* parent = builder->shared->module;
-
IRConstant keyInst;
memset(&keyInst, 0, sizeof(keyInst));
keyInst.op = op;
- keyInst.type.usedValue = type;
+ keyInst.type = type;
memcpy(&keyInst.u, value, valueSize);
IRConstantKey key;
key.inst = &keyInst;
- IRConstant* inst = nullptr;
- if( builder->shared->constantMap.TryGetValue(key, inst) )
+ IRConstant* irValue = nullptr;
+ if( builder->shared->constantMap.TryGetValue(key, irValue) )
{
// We found a match, so just use that.
- return inst;
+ return irValue;
}
// We now know where we want to insert, but there might
@@ -693,221 +494,25 @@ namespace Slang
// way: we will construct a temporary instruction and
// then use it to look up in a cache of instructions.
- inst = createInst<IRConstant>(builder, op, type);
- memcpy(&inst->u, value, valueSize);
+ irValue = createInst<IRConstant>(builder, op, type);
+ memcpy(&irValue->u, value, valueSize);
- key.inst = inst;
- builder->shared->constantMap.Add(key, inst);
+ key.inst = irValue;
+ builder->shared->constantMap.Add(key, irValue);
- builder->addInst(parent, inst);
-
- return inst;
+ return irValue;
}
//
- static IRType* getBaseTypeImpl(IRBuilder* builder, IROp op)
- {
- auto inst = findOrEmitInst<IRType>(
- builder,
- op,
- builder->getTypeType());
- return inst;
- }
-
- IRType* IRBuilder::getBaseType(BaseType flavor)
- {
- switch( flavor )
- {
- case BaseType::Void: return getVoidType();
-
- case BaseType::Bool: return getBaseTypeImpl(this, kIROp_BoolType);
- case BaseType::Float: return getBaseTypeImpl(this, kIROp_Float32Type);
- case BaseType::Int: return getBaseTypeImpl(this, kIROp_Int32Type);
- case BaseType::UInt: return getBaseTypeImpl(this, kIROp_UInt32Type);
-
- default:
- SLANG_UNEXPECTED("unhandled base type");
- return nullptr;
- }
- }
-
- IRType* IRBuilder::getBoolType()
- {
- return getBaseType(BaseType::Bool);
- }
-
- IRType* IRBuilder::getVectorType(IRType* elementType, IRValue* elementCount)
- {
- return findOrEmitInst<IRVectorType>(
- this,
- kIROp_VectorType,
- getTypeType(),
- elementType,
- elementCount);
- }
-
- IRType* IRBuilder::getMatrixType(
- IRType* elementType,
- IRValue* rowCount,
- IRValue* columnCount)
- {
- return findOrEmitInst<IRMatrixType>(
- this,
- kIROp_MatrixType,
- getTypeType(),
- elementType,
- rowCount,
- columnCount);
- }
-
- IRType* IRBuilder::getArrayType(IRType* elementType, IRValue* elementCount)
- {
- // The client requests an unsized array by passing `nullptr` for
- // the `elementCount`.
- //
- // We currently encode an unsized array as an ordinary array with
- // zero elements. TODO: carefully consider this choice.
- if (!elementCount)
- {
- elementCount = getIntValue(
- getBaseType(BaseType::Int),
- 0);
- }
-
- return findOrEmitInst<IRArrayType>(
- this,
- kIROp_arrayType,
- getTypeType(),
- elementType,
- elementCount);
- }
-
- IRType* IRBuilder::getArrayType(IRType* elementType)
- {
- return getArrayType(elementType, nullptr);
- }
-
- IRType* IRBuilder::getGenericParameterType(UInt index)
- {
- auto indexVal = getIntValue(getBaseType(BaseType::Int), index);
-
- return findOrEmitInst<IRGenericParameterType>(
- this,
- kIROp_GenericParameterType,
- getTypeType(),
- indexVal);
-
- }
-
-
- IRType* IRBuilder::getTypeType()
- {
- return findOrEmitInst<IRType>(
- this,
- kIROp_TypeType,
- nullptr);
- }
-
- IRType* IRBuilder::getVoidType()
- {
- return findOrEmitInst<IRType>(
- this,
- kIROp_VoidType,
- getTypeType());
- }
-
- IRType* IRBuilder::getBlockType()
- {
- return findOrEmitInst<IRType>(
- this,
- kIROp_BlockType,
- getTypeType());
- }
-
- IRType* IRBuilder::getIntrinsicType(
- IROp op,
- UInt argCount,
- IRValue* const* args)
- {
- return findOrEmitInst<IRType>(
- this,
- op,
- getTypeType(),
- 0,
- nullptr,
- argCount,
- args);
- }
-
-
- IRStructDecl* IRBuilder::createStructType()
- {
- return createInst<IRStructDecl>(
- this,
- kIROp_StructType,
- getTypeType());
- }
-
- IRStructField* IRBuilder::createStructField(IRType* fieldType)
- {
- return createInst<IRStructField>(
- this,
- kIROp_StructField,
- fieldType);
- }
-
-
- IRType* IRBuilder::getFuncType(
- UInt paramCount,
- IRType* const* paramTypes,
- IRType* resultType)
- {
- // TODO: need to unique things here!
- auto inst = createInstWithTrailingArgs<IRFuncType>(
- this,
- kIROp_FuncType,
- getTypeType(),
- 1,
- (IRValue* const*) &resultType,
- paramCount,
- (IRValue* const*) paramTypes);
- addInst(inst);
- return inst;
- }
-
- IRType* IRBuilder::getPtrType(
- IRType* valueType,
- IRAddressSpace addressSpace)
- {
- auto uintType = getBaseType(BaseType::UInt);
- auto irAddressSpace = getIntValue(uintType, addressSpace);
-
- auto inst = findOrEmitInst<IRPtrType>(
- this,
- kIROp_PtrType,
- getTypeType(),
- valueType,
- irAddressSpace);
- return inst;
- }
-
-
- IRType* IRBuilder::getPtrType(
- IRType* valueType)
- {
- return getPtrType(valueType, kIRAddressSpace_Default);
- }
-
-
IRValue* IRBuilder::getBoolValue(bool inValue)
{
IRIntegerValue value = inValue;
return findOrEmitConstant(
this,
kIROp_boolConst,
- getBoolType(),
+ getSession()->getBoolType(),
sizeof(value),
&value);
}
@@ -932,6 +537,18 @@ namespace Slang
&value);
}
+ IRValue* IRBuilder::getDeclRefVal(
+ DeclRefBase const& declRef)
+ {
+ // TODO: we should cache these...
+ auto irValue = createInst<IRDeclRef>(
+ this,
+ kIROp_decl_ref,
+ nullptr);
+ irValue->declRef = declRef;
+ return irValue;
+ }
+
IRInst* IRBuilder::emitCallInst(
IRType* type,
IRValue* func,
@@ -983,52 +600,69 @@ namespace Slang
IRModule* IRBuilder::createModule()
{
- return createInst<IRModule>(
- this,
- kIROp_Module,
- nullptr);
+ return new IRModule();
}
IRFunc* IRBuilder::createFunc()
{
- return createInst<IRFunc>(
+ return createValue<IRFunc>(
this,
kIROp_Func,
nullptr);
}
+ IRGlobalVar* IRBuilder::createGlobalVar(
+ IRType* valueType)
+ {
+ auto ptrType = getSession()->getPtrType(valueType);
+ return createValue<IRGlobalVar>(
+ this,
+ kIROp_global_var,
+ ptrType);
+ }
+
IRBlock* IRBuilder::createBlock()
{
- return createInst<IRBlock>(
+ return createValue<IRBlock>(
this,
kIROp_Block,
- getBlockType());
+ getSession()->getIRBasicBlockType());
}
IRBlock* IRBuilder::emitBlock()
{
- auto inst = createBlock();
- addInst(inst);
- return inst;
+ auto bb = createBlock();
+
+ auto f = this->func;
+ if (f)
+ {
+ f->addBlock(bb);
+ this->block = bb;
+ }
+ return bb;
}
IRParam* IRBuilder::emitParam(
IRType* type)
{
- auto inst = createInst<IRParam>(
+ auto param = createValue<IRParam>(
this,
kIROp_Param,
type);
- addInst(inst);
- return inst;
+
+ if (auto bb = block)
+ {
+ bb->addParam(param);
+ }
+ return param;
}
IRVar* IRBuilder::emitVar(
IRType* type,
IRAddressSpace addressSpace)
{
- auto allocatedType = getPtrType(type, addressSpace);
+ auto allocatedType = getSession()->getPtrType(type);
auto inst = createInst<IRVar>(
this,
kIROp_Var,
@@ -1047,14 +681,14 @@ namespace Slang
IRInst* IRBuilder::emitLoad(
IRValue* ptr)
{
- auto ptrType = ptr->getType();
- if( ptrType->op != kIROp_PtrType )
+ auto ptrType = ptr->getType()->As<PtrType>();
+ if( !ptrType )
{
// Bad!
return nullptr;
}
- auto valueType = ((IRPtrType*) ptrType)->getValueType();
+ auto valueType = ptrType->getValueType();
auto inst = createInst<IRLoad>(
this,
@@ -1070,11 +704,10 @@ namespace Slang
IRValue* dstPtr,
IRValue* srcVal)
{
- auto type = getVoidType();
auto inst = createInst<IRStore>(
this,
kIROp_Store,
- type,
+ nullptr,
dstPtr,
srcVal);
@@ -1085,7 +718,7 @@ namespace Slang
IRInst* IRBuilder::emitFieldExtract(
IRType* type,
IRValue* base,
- IRStructField* field)
+ IRValue* field)
{
auto inst = createInst<IRFieldExtract>(
this,
@@ -1101,7 +734,7 @@ namespace Slang
IRInst* IRBuilder::emitFieldAddress(
IRType* type,
IRValue* base,
- IRStructField* field)
+ IRValue* field)
{
auto inst = createInst<IRFieldAddress>(
this,
@@ -1170,7 +803,7 @@ namespace Slang
UInt elementCount,
UInt const* elementIndices)
{
- auto intType = getBaseType(BaseType::Int);
+ auto intType = getSession()->getBuiltinType(BaseType::Int);
IRValue* irElementIndices[4];
for (UInt ii = 0; ii < elementCount; ++ii)
@@ -1212,7 +845,7 @@ namespace Slang
UInt elementCount,
UInt const* elementIndices)
{
- auto intType = getBaseType(BaseType::Int);
+ auto intType = getSession()->getBuiltinType(BaseType::Int);
IRValue* irElementIndices[4];
for (UInt ii = 0; ii < elementCount; ++ii)
@@ -1229,7 +862,7 @@ namespace Slang
auto inst = createInst<IRReturnVal>(
this,
kIROp_ReturnVal,
- getVoidType(),
+ nullptr,
val);
addInst(inst);
return inst;
@@ -1240,7 +873,7 @@ namespace Slang
auto inst = createInst<IRReturnVoid>(
this,
kIROp_ReturnVoid,
- getVoidType());
+ nullptr);
addInst(inst);
return inst;
}
@@ -1251,7 +884,7 @@ namespace Slang
auto inst = createInst<IRUnconditionalBranch>(
this,
kIROp_unconditionalBranch,
- getVoidType(),
+ nullptr,
block);
addInst(inst);
return inst;
@@ -1263,7 +896,7 @@ namespace Slang
auto inst = createInst<IRBreak>(
this,
kIROp_break,
- getVoidType(),
+ nullptr,
target);
addInst(inst);
return inst;
@@ -1275,7 +908,7 @@ namespace Slang
auto inst = createInst<IRContinue>(
this,
kIROp_continue,
- getVoidType(),
+ nullptr,
target);
addInst(inst);
return inst;
@@ -1286,13 +919,13 @@ namespace Slang
IRBlock* breakBlock,
IRBlock* continueBlock)
{
- IRInst* args[] = { target, breakBlock, continueBlock };
+ IRValue* args[] = { target, breakBlock, continueBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRLoop>(
this,
kIROp_loop,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1304,13 +937,13 @@ namespace Slang
IRBlock* trueBlock,
IRBlock* falseBlock)
{
- IRInst* args[] = { val, trueBlock, falseBlock };
+ IRValue* args[] = { val, trueBlock, falseBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRConditionalBranch>(
this,
kIROp_conditionalBranch,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1322,13 +955,13 @@ namespace Slang
IRBlock* trueBlock,
IRBlock* afterBlock)
{
- IRInst* args[] = { val, trueBlock, afterBlock };
+ IRValue* args[] = { val, trueBlock, afterBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRIf>(
this,
kIROp_if,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1341,13 +974,13 @@ namespace Slang
IRBlock* falseBlock,
IRBlock* afterBlock)
{
- IRInst* args[] = { val, trueBlock, falseBlock, afterBlock };
+ IRValue* args[] = { val, trueBlock, falseBlock, afterBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRIfElse>(
this,
kIROp_ifElse,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1359,13 +992,13 @@ namespace Slang
IRBlock* bodyBlock,
IRBlock* breakBlock)
{
- IRInst* args[] = { val, bodyBlock, breakBlock };
+ IRValue* args[] = { val, bodyBlock, breakBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRLoopTest>(
this,
kIROp_loopTest,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1373,7 +1006,7 @@ namespace Slang
}
IRDecoration* IRBuilder::addDecorationImpl(
- IRInst* inst,
+ IRValue* inst,
UInt decorationSize,
IRDecorationOp op)
{
@@ -1388,14 +1021,14 @@ namespace Slang
return decoration;
}
- IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl)
+ IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRValue* inst, Decl* decl)
{
auto decoration = addDecoration<IRHighLevelDeclDecoration>(inst, kIRDecorationOp_HighLevelDecl);
decoration->decl = decl;
return decoration;
}
- IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRInst* inst, Layout* layout)
+ IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRValue* inst, Layout* layout)
{
auto decoration = addDecoration<IRLayoutDecoration>(inst);
decoration->layout = layout;
@@ -1409,6 +1042,9 @@ namespace Slang
{
StringBuilder* builder;
int indent;
+
+ UInt idCounter = 1;
+ Dictionary<IRValue*, UInt> mapValueToID;
};
static void dump(
@@ -1454,27 +1090,59 @@ namespace Slang
}
}
+ bool opHasResult(IRValue* inst);
+
+ static UInt getID(
+ IRDumpContext* context,
+ IRValue* value)
+ {
+ UInt id = 0;
+ if (context->mapValueToID.TryGetValue(value, id))
+ return id;
+
+ if (opHasResult(value))
+ {
+ id = context->idCounter++;
+ }
+
+ context->mapValueToID.Add(value, id);
+ return id;
+ }
+
static void dumpID(
IRDumpContext* context,
- IRInst* inst)
+ IRValue* inst)
{
if (!inst)
{
dump(context, "<null>");
+ return;
}
- else if( auto mangled = inst->findDecoration<IRMangledNameDecoration>() )
- {
- dump(context, "@");
- dump(context, mangled->mangledName.Buffer());
- }
- else if(inst->id)
- {
- dump(context, "%");
- dump(context, (UInt) inst->id);
- }
- else
+
+ switch(inst->op)
{
- dump(context, "_");
+ case kIROp_Func:
+ {
+ auto irFunc = (IRFunc*) inst;
+ dump(context, "@");
+ dump(context, irFunc->mangledName.Buffer());
+ }
+ break;
+
+ default:
+ {
+ UInt id = getID(context, inst);
+ if (id)
+ {
+ dump(context, "%");
+ dump(context, id);
+ }
+ else
+ {
+ dump(context, "_");
+ }
+ }
+ break;
}
}
@@ -1484,8 +1152,9 @@ namespace Slang
static void dumpOperand(
IRDumpContext* context,
- IRInst* inst)
+ IRValue* inst)
{
+ // TODO: we should have a dedicated value for the `undef` case
if (!inst)
{
dump(context, "undef");
@@ -1514,22 +1183,90 @@ namespace Slang
break;
}
- auto type = inst->getType();
- if (type)
+ dumpID(context, inst);
+ }
+
+ static void dump(
+ IRDumpContext* context,
+ Name* name)
+ {
+ dump(context, getText(name).Buffer());
+ }
+
+
+ static void dumpDeclRef(
+ IRDumpContext* context,
+ DeclRef<Decl> const& declRef);
+
+ static void dumpVal(
+ IRDumpContext* context,
+ Val* val)
+ {
+ if(auto type = dynamic_cast<Type*>(val))
{
- switch (type->op)
- {
- case kIROp_TypeType:
- dumpType(context, (IRType*)inst);
- return;
+ dumpType(context, type);
+ }
+ else if(auto constIntVal = dynamic_cast<ConstantIntVal*>(val))
+ {
+ dump(context, constIntVal->value);
+ }
+ else if(auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val))
+ {
+ dumpDeclRef(context, genericParamVal->declRef);
+ }
+ else
+ {
+ dump(context, "???");
+ }
+ }
- default:
- break;
- }
+ static void dumpDeclRef(
+ IRDumpContext* context,
+ DeclRef<Decl> const& declRef)
+ {
+ auto decl = declRef.getDecl();
+
+ auto parentDeclRef = declRef.GetParent();
+ auto genericParentDeclRef = parentDeclRef.As<GenericDecl>();
+ if(genericParentDeclRef)
+ {
+ parentDeclRef = genericParentDeclRef.GetParent();
}
+ if(parentDeclRef.As<ModuleDecl>())
+ {
+ parentDeclRef = DeclRef<ContainerDecl>();
+ }
- dumpID(context, inst);
+ if(parentDeclRef)
+ {
+ dumpDeclRef(context, parentDeclRef);
+ dump(context, ".");
+ }
+ dump(context, decl->getName());
+
+ if(genericParentDeclRef)
+ {
+ auto subst = declRef.substitutions;
+ if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() )
+ {
+ // No actual substitutions in place here
+ dump(context, "<>");
+ }
+ else
+ {
+ auto args = subst->args;
+ bool first = true;
+ dump(context, "<");
+ for(auto aa : args)
+ {
+ if(!first) dump(context, ",");
+ dumpVal(context, aa);
+ first = false;
+ }
+ dump(context, ">");
+ }
+ }
}
static void dumpType(
@@ -1538,10 +1275,43 @@ namespace Slang
{
if (!type)
{
- dumpID(context, type);
+ dump(context, "_");
return;
}
+ if(auto funcType = type->As<FuncType>())
+ {
+ UInt paramCount = funcType->getParamCount();
+ dump(context, "(");
+ for( UInt pp = 0; pp < paramCount; ++pp )
+ {
+ if(pp != 0) dump(context, ", ");
+ dumpType(context, funcType->getParamType(pp));
+ }
+ dump(context, ") -> ");
+ dumpType(context, funcType->getResultType());
+ }
+ else if(auto arrayType = type->As<ArrayExpressionType>())
+ {
+ dumpType(context, arrayType->baseType);
+ dump(context, "[");
+ if(auto elementCount = arrayType->ArrayLength)
+ {
+ dumpVal(context, elementCount);
+ }
+ dump(context, "]");
+ }
+ else if(auto declRefType = type->As<DeclRefType>())
+ {
+ dumpDeclRef(context, declRefType->declRef);
+ }
+ else
+ {
+ // Need a default case here
+ dump(context, "???");
+ }
+
+#if 0
auto op = type->op;
auto opInfo = kIROpInfos[op];
@@ -1551,21 +1321,6 @@ namespace Slang
dumpID(context, type);
break;
- case kIROp_FuncType:
- {
- auto funcType = (IRFuncType*) type;
- UInt paramCount = funcType->getParamCount();
- dump(context, "(");
- for( UInt pp = 0; pp < paramCount; ++pp )
- {
- if(pp != 0) dump(context, ", ");
- dumpType(context, funcType->getParamType(pp));
- }
- dump(context, ") -> ");
- dumpType(context, funcType->getResultType());
- }
- break;
-
default:
{
dump(context, opInfo.name);
@@ -1585,6 +1340,7 @@ namespace Slang
}
break;
}
+#endif
}
static void dumpInstTypeClause(
@@ -1602,32 +1358,68 @@ namespace Slang
static void dumpChildrenRaw(
IRDumpContext* context,
- IRParentInst* parent)
+ IRBlock* block)
{
- for (auto ii = parent->firstChild; ii; ii = ii->nextInst)
+ for (auto ii = block->firstInst; ii; ii = ii->nextInst)
{
dumpInst(context, ii);
}
}
- static void dumpChildren(
+ static void dumpBlock(
IRDumpContext* context,
- IRInst* inst)
+ IRBlock* block)
{
- auto op = inst->op;
- auto opInfo = &kIROpInfos[op];
- if (opInfo->flags & kIROpFlag_Parent)
+ context->indent--;
+ dump(context, "block ");
+ dumpID(context, block);
+
+ if( block->getFirstParam() )
{
- dumpIndent(context);
- dump(context, "{\n");
- context->indent++;
- auto parent = (IRParentInst*)inst;
- dumpChildrenRaw(context, parent);
- context->indent--;
- dumpIndent(context);
- dump(context, "}\n");
+ dump(context, "(\n");
+ context->indent += 2;
+ for (auto pp = block->getFirstParam(); pp; pp = pp->getNextParam())
+ {
+ if (pp != block->getFirstParam())
+ dump(context, ",\n");
+
+ dumpIndent(context);
+ dump(context, "param ");
+ dumpID(context, pp);
+ dumpInstTypeClause(context, pp->getType());
+ }
+ context->indent -= 2;
+ dump(context, ")");
+ }
+ dump(context, ":\n");
+ context->indent++;
+
+ dumpChildrenRaw(context, block);
+ }
+
+ static void dumpChildrenRaw(
+ IRDumpContext* context,
+ IRFunc* func)
+ {
+ for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock())
+ {
+ dumpBlock(context, bb);
}
}
+
+ static void dumpChildren(
+ IRDumpContext* context,
+ IRFunc* func)
+ {
+ dumpIndent(context);
+ dump(context, "{\n");
+ context->indent++;
+ dumpChildrenRaw(context, func);
+ context->indent--;
+ dumpIndent(context);
+ dump(context, "}\n");
+ }
+
static void dumpInst(
IRDumpContext* context,
IRInst* inst)
@@ -1646,6 +1438,7 @@ namespace Slang
//
switch (op)
{
+#if 0
case kIROp_Module:
dumpIndent(context);
dump(context, "module\n");
@@ -1717,11 +1510,13 @@ namespace Slang
dumpChildrenRaw(context, block);
}
return;
+#endif
default:
break;
}
+#if 0
// We also want to special-case based on the *type*
// of the instruction
auto type = inst->getType();
@@ -1738,38 +1533,42 @@ namespace Slang
return;
}
}
+#endif
// Okay, we have a seemingly "ordinary" op now
dumpIndent(context);
auto opInfo = &kIROpInfos[op];
+ auto type = inst->getType();
- if (type && type->op == kIROp_TypeType)
- {
- dump(context, "type ");
- dumpID(context, inst);
- dump(context, "\t= ");
- }
- else if (type && type->op == kIROp_VoidType)
+ if (!type)
{
+ // No result, okay...
}
else
{
- dump(context, "let ");
- dumpID(context, inst);
- dumpInstTypeClause(context, type);
- dump(context, "\t= ");
+ auto basicType = type->As<BasicExpressionType>();
+ if (basicType && basicType->baseType == BaseType::Void)
+ {
+ // No result, okay...
+ }
+ else
+ {
+ dump(context, "let ");
+ dumpID(context, inst);
+ dumpInstTypeClause(context, type);
+ dump(context, "\t= ");
+ }
}
-
dump(context, opInfo->name);
uint32_t argCount = inst->argCount;
dump(context, "(");
- for (uint32_t ii = 1; ii < argCount; ++ii)
+ for (uint32_t ii = 0; ii < argCount; ++ii)
{
- if (ii != 1)
+ if (ii != 0)
dump(context, ", ");
auto argVal = inst->getArgs()[ii].usedValue;
@@ -1779,10 +1578,78 @@ namespace Slang
dump(context, ")");
dump(context, "\n");
+ }
+
+ void dumpIRFunc(
+ IRDumpContext* context,
+ IRFunc* func)
+ {
+ dump(context, "\n");
+ dumpIndent(context);
+ dump(context, "ir_func ");
+ dumpID(context, func);
+ dumpInstTypeClause(context, func->getType());
+ dump(context, "\n");
+
+ dumpIndent(context);
+ dump(context, "{\n");
+ context->indent++;
- // The instruction might have children,
- // so we need to handle those here
- dumpChildren(context, inst);
+ for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock())
+ {
+ if (bb != func->getFirstBlock())
+ dump(context, "\n");
+ dumpBlock(context, bb);
+ }
+
+ context->indent--;
+ dump(context, "}\n");
+ }
+
+ void dumpIRGlobalVar(
+ IRDumpContext* context,
+ IRGlobalVar* var)
+ {
+ dump(context, "\n");
+ dumpIndent(context);
+ dump(context, "ir_global_var ");
+ dumpID(context, var);
+ dumpInstTypeClause(context, var->getType());
+
+ // TODO: deal with the case where a global
+ // might have embedded initialization logic.
+
+ dump(context, ";\n");
+ }
+
+ void dumpIRGlobalValue(
+ IRDumpContext* context,
+ IRGlobalValue* value)
+ {
+ switch (value->op)
+ {
+ case kIROp_Func:
+ dumpIRFunc(context, (IRFunc*)value);
+ break;
+
+ case kIROp_global_var:
+ dumpIRGlobalVar(context, (IRGlobalVar*)value);
+ break;
+
+ default:
+ dump(context, "???\n");
+ break;
+ }
+ }
+
+ void dumpIRModule(
+ IRDumpContext* context,
+ IRModule* module)
+ {
+ for (auto gv : module->globalValues)
+ {
+ dumpIRGlobalValue(context, gv);
+ }
}
void printSlangIRAssembly(StringBuilder& builder, IRModule* module)
@@ -1791,7 +1658,7 @@ namespace Slang
context.builder = &builder;
context.indent = 0;
- dumpChildrenRaw(&context, module);
+ dumpIRModule(&context, module);
}
String getSlangIRAssembly(IRModule* module)
diff --git a/source/slang/ir.h b/source/slang/ir.h
index 9c3124478..7e4dc8f83 100644
--- a/source/slang/ir.h
+++ b/source/slang/ir.h
@@ -11,36 +11,14 @@
namespace Slang {
-// TODO(tfoley): We should ditch this enumeration
-// and just use the IR opcodes that represent these
-// types directly. The one major complication there
-// is that the order of the enum values currently
-// matters, since it determines promotion rank.
-// We either need to keep that restriction, or
-// look up promotion rank by some other means.
-//
-enum class BaseType
-{
- // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this
-
- Void = 0,
- Bool,
- Int,
- UInt,
- UInt64,
- Half,
- Float,
- Double,
-};
-
+class FuncType;
+class Layout;
+class Type;
-class Layout;
-
-struct IRFunc;
-struct IRInst;
-struct IRModule;
-struct IRParentInst;
-struct IRType;
+struct IRFunc;
+struct IRInst;
+struct IRModule;
+struct IRValue;
typedef unsigned int IROpFlags;
enum : IROpFlags
@@ -75,34 +53,6 @@ enum IROp : int16_t
};
-#if 0
-enum IRPseudoOp
-{
- kIRPseudoOp_Pos = -1000,
- kIRPseudoOp_PreInc,
- kIRPseudoOp_PreDec,
- kIRPseudoOp_PostInc,
- kIRPseudoOp_PostDec,
- kIRPseudoOp_Sequence,
- kIRPseudoOp_AddAssign,
- kIRPseudoOp_SubAssign,
- kIRPseudoOp_MulAssign,
- kIRPseudoOp_DivAssign,
- kIRPseudoOp_ModAssign,
- kIRPseudoOp_AndAssign,
- kIRPseudoOp_OrAssign,
- kIRPseudoOp_XorAssign ,
- kIRPseudoOp_LshAssign,
- kIRPseudoOp_RshAssign,
- kIRPseudoOp_Assign,
- kIRPseudoOp_BitNot,
- kIRPseudoOp_And,
- kIRPseudoOp_Or,
-
- kIROp_Invalid = -1,
-};
-#endif
-
IROp findIROp(char const* name);
// A logical operation/opcode in the IR
@@ -125,12 +75,12 @@ IROpInfo getIROpInfo(IROp op);
// A use of another value/inst within an IR operation
struct IRUse
{
+ // The value that is being used
+ IRValue* usedValue;
+
// The value that is doing the using.
IRInst* user;
- // The value that is being used
- IRInst* usedValue;
-
// The next use of the same value
IRUse* nextUse;
@@ -138,7 +88,7 @@ struct IRUse
// so that we can simplify updates.
IRUse** prevLink;
- void init(IRInst* user, IRInst* usedValue);
+ void init(IRInst* user, IRValue* usedValue);
};
enum IRDecorationOp : uint16_t
@@ -162,52 +112,61 @@ struct IRDecoration
IRDecorationOp op;
};
-typedef uint32_t IRInstID;
+// Use AST-level types directly to represent the
+// types of IR instructions/values
+typedef Type IRType;
+
+struct IRBlock;
-// In the IR, almost *everything* is an instruction,
-// in order to make the representation as uniform as possible.
-struct IRInst
+// Base class for values in the IR
+struct IRValue
{
// The operation that this value represents
IROp op;
- // A unique ID to represent the op when printing
- // (or zero to indicate that the value of this
- // op isn't special).
- IRInstID id;
+ // The type of the result value of this instruction,
+ // or `null` to indicate that the instruction has
+ // no value.
+ RefPtr<Type> type;
- // The total number of arguments of this instruction
- // (including the type)
- uint32_t argCount;
-
- // The parent of this instruction.
- // This will often be a basic block, but we
- // allow instructions to nest in more general ways.
- IRParentInst* parent;
-
- // The next and previous instructions in the same parent block
- IRInst* nextInst;
- IRInst* prevInst;
+ Type* getType() { return type; }
- // The first use of this value (start of a linked list)
- IRUse* firstUse;
-
- // The linked list of decorations attached to this instruction
+ // The linked list of decorations attached to this value
IRDecoration* firstDecoration;
+ // Look up a decoration in the list of decorations
IRDecoration* findDecorationImpl(IRDecorationOp op);
-
template<typename T>
T* findDecoration()
{
return (T*) findDecorationImpl(IRDecorationOp(T::kDecorationOp));
}
+ // The first use of this value (start of a linked list)
+ IRUse* firstUse;
- // The type of this value
- IRUse type;
- IRType* getType() { return (IRType*) type.usedValue; }
+};
+
+// Instructions are values that can be executed,
+// and which take other values as operands
+struct IRInst : IRValue
+{
+ // The total number of arguments of this instruction.
+ //
+ // TODO: We shouldn't need to allocate this on
+ // all instructions. Instead we should have
+ // instructions that need "vararg" support to
+ // allocate this field ahead of the `this`
+ // pointer.
+ uint32_t argCount;
+
+ // The basic block that contains this instruction.
+ IRBlock* parentBlock;
+
+ // The next and previous instructions in the same parent block
+ IRInst* nextInst;
+ IRInst* prevInst;
UInt getArgCount()
{
@@ -216,20 +175,16 @@ struct IRInst
IRUse* getArgs();
- IRInst* getArg(UInt index)
+ IRValue* getArg(UInt index)
{
return getArgs()[index].usedValue;
}
};
-// This type alias exists because I waffled on the name for a bit.
-// All existing uses of `IRValue` should move to `IRInst`
-typedef IRInst IRValue;
-
typedef int64_t IRIntegerValue;
typedef double IRFloatingPointValue;
-struct IRConstant : IRInst
+struct IRConstant : IRValue
{
union
{
@@ -241,16 +196,6 @@ struct IRConstant : IRInst
} u;
};
-// Representation of a type at the IR level.
-// Such a type may not correspond to the high-level-language notion
-// of a type as used by the front end.
-//
-// Note that types are instructions in the IR, so that operations
-// may take type operands as easily as values.
-struct IRType : IRInst
-{
-};
-
// A instruction that ends a basic block (usually because of control flow)
struct IRTerminatorInst : IRInst
{};
@@ -258,23 +203,20 @@ struct IRTerminatorInst : IRInst
bool isTerminatorInst(IROp op);
bool isTerminatorInst(IRInst* inst);
-
-// A parent instruction contains a sequence of other instructions
+// A function parameter is owned by a basic block, and represents
+// either an incoming function parameter (in the entry block), or
+// a value that flows from one SSA block to another (in a non-entry
+// block).
//
-struct IRParentInst : IRInst
+// In each case, the basic idea is that a block is a "label with
+// arguments."
+struct IRParam : IRValue
{
- // The first and last instruction in the container (or NULL in
- // the case that the container is empty).
- //
- IRInst* firstChild;
- IRInst* lastChild;
-};
+ IRParam* nextParam;
+ IRParam* prevParam;
-// A function parameter is represented by an instruction
-// in the entry block of a function.
-struct IRParam : IRInst
-{
- IRParam* getNextParam();
+ IRParam* getNextParam() { return nextParam; }
+ IRParam* getPrevParam() { return prevParam; }
};
// A basic block is a parent instruction that adds the constraint
@@ -282,48 +224,97 @@ struct IRParam : IRInst
// no function declarations, or nested blocks). We also expect
// that the previous/next instruction are always a basic block.
//
-struct IRBlock : IRParentInst
+struct IRBlock : IRValue
{
+ // Linked list of the instructions contained in this block
+ //
// Note that in a valid program, every block must end with
// a "terminator" instruction, so these should be non-NULL,
- // and `last` should actually be an `IRTerminatorInst`.
+ // and `lastInst` should actually be an `IRTerminatorInst`.
+ IRInst* firstInst;
+ IRInst* lastInst;
- IRBlock* getPrevBlock() { return (IRBlock*) prevInst; }
- IRBlock* getNextBlock() { return (IRBlock*) nextInst; }
+ IRInst* getFirstInst() { return firstInst; }
+ IRInst* getLastInst() { return lastInst; }
- IRFunc* getParent() { return (IRFunc*)parent; }
+ // Links for the list of basic blocks in the parent function
+ IRBlock* prevBlock;
+ IRBlock* nextBlock;
+
+ IRBlock* getPrevBlock() { return prevBlock; }
+ IRBlock* getNextBlock() { return nextBlock; }
+
+ // Linked list of parameters of this block
+ IRParam* firstParam;
+ IRParam* lastParam;
+
+ IRParam* getFirstParam() { return firstParam; }
+ IRParam* getLastParam() { return lastParam; }
+ void addParam(IRParam* param);
+
+ // The parent function that contains this block
+ IRFunc* parentFunc;
+
+ IRFunc* getParent() { return parentFunc; }
- IRParam* getFirstParam();
};
-struct IRFuncType;
+// For right now, we will represent the type of
+// an IR function using the type of the AST
+// function from which it was created.
+//
+// TODO: need to do this better.
+typedef FuncType IRFuncType;
+
+struct IRGlobalValue : IRValue
+{};
// A function is a parent to zero or more blocks of instructions.
//
// A function is itself a value, so that it can be a direct operand of
// an instruction (e.g., a call).
-struct IRFunc : IRParentInst
+struct IRFunc : IRGlobalValue
{
- IRFuncType* getType() { return (IRFuncType*) type.usedValue; }
+ // The type of the IR-level function
+ IRFuncType* getType() { return (IRFuncType*) type.Ptr(); }
+
+ // The mangled name, for a function
+ // that should have linkage.
+ String mangledName;
- IRType* getResultType();
+ // Convenience accessors for working with the
+ // function's type.
+ Type* getResultType();
UInt getParamCount();
- IRType* getParamType(UInt index);
+ Type* getParamType(UInt index);
- IRBlock* getFirstBlock() { return (IRBlock*) firstChild; }
- IRBlock* getLastBlock() { return (IRBlock*) lastChild; }
+ // The list of basic blocks in this function
+ IRBlock* firstBlock = nullptr;
+ IRBlock* lastBlock = nullptr;
+ IRBlock* getFirstBlock() { return firstBlock; }
+ IRBlock* getLastBlock() { return lastBlock; }
+
+ // Add a block to the end of this function.
+ void addBlock(IRBlock* block);
+
+ // Convenience accessor for the IR parameters,
+ // which are actually the parameters of the first
+ // block.
IRParam* getFirstParam();
};
// A module is a parent to functions, global variables, types, etc.
-struct IRModule : IRParentInst
+struct IRModule : RefObject
{
// The designated entry-point function, if any
IRFunc* entryPoint;
- // A special counter used to assign logical ids to instructions in this module.
- IRInstID idCounter;
+ // A list of all the functions and other
+ // global values declared in this module.
+ List<IRGlobalValue*> globalValues;
+
+ // TODO: need a symbol of all the global variables too
};
void printSlangIRAssembly(StringBuilder& builder, IRModule* module);
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 06ad66bc4..7fce0c385 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -79,7 +79,7 @@ struct SubscriptInfo : ExtendedValueInfo
struct BoundSubscriptInfo : ExtendedValueInfo
{
DeclRef<SubscriptDecl> declRef;
- IRType* type;
+ RefPtr<Type> type;
List<IRValue*> args;
UInt genericArgCount;
};
@@ -218,8 +218,8 @@ struct BoundMemberInfo : ExtendedValueInfo
//
struct SwizzledLValueInfo : ExtendedValueInfo
{
- // IR-level The type of the expression.
- IRType* type;
+ // The type of the expression.
+ RefPtr<Type> type;
// The base expression (this should be an l-value)
LoweredValInfo base;
@@ -355,7 +355,7 @@ LoweredValInfo emitCompoundAssignOp(
auto leftVal = builder->emitLoad(leftPtr);
- IRInst* innerArgs[] = { leftVal, rightVal };
+ IRValue* innerArgs[] = { leftVal, rightVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(leftPtr, innerOp);
@@ -363,23 +363,31 @@ LoweredValInfo emitCompoundAssignOp(
return LoweredValInfo::ptr(leftPtr);
}
-IRInst* getOneValOfType(
+IRValue* getOneValOfType(
IRGenContext* context,
IRType* type)
{
- switch(type->op)
+ if (auto basicType = dynamic_cast<BasicExpressionType*>(type))
{
- case kIROp_Int32Type:
- case kIROp_UInt32Type:
- return context->irBuilder->getIntValue(type, 1);
+ switch (basicType->baseType)
+ {
+ case BaseType::Int:
+ case BaseType::UInt:
+ case BaseType::UInt64:
+ return context->irBuilder->getIntValue(type, 1);
- case kIROp_Float32Type:
- return context->irBuilder->getFloatValue(type, 1.0);
+ case BaseType::Float:
+ case BaseType::Double:
+ return context->irBuilder->getFloatValue(type, 1.0);
- default:
- SLANG_UNEXPECTED("inc/dec type");
- return nullptr;
+ default:
+ break;
+ }
}
+ // TODO: should make sure to handle vector and matrix types here
+
+ SLANG_UNEXPECTED("inc/dec type");
+ return nullptr;
}
LoweredValInfo emitPreOp(
@@ -396,9 +404,9 @@ LoweredValInfo emitPreOp(
auto preVal = builder->emitLoad(argPtr);
- IRInst* oneVal = getOneValOfType(context, type);
+ IRValue* oneVal = getOneValOfType(context, type);
- IRInst* innerArgs[] = { preVal, oneVal };
+ IRValue* innerArgs[] = { preVal, oneVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(argPtr, innerOp);
@@ -420,9 +428,9 @@ LoweredValInfo emitPostOp(
auto preVal = builder->emitLoad(argPtr);
- IRInst* oneVal = getOneValOfType(context, type);
+ IRValue* oneVal = getOneValOfType(context, type);
- IRInst* innerArgs[] = { preVal, oneVal };
+ IRValue* innerArgs[] = { preVal, oneVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(argPtr, innerOp);
@@ -647,10 +655,7 @@ struct LoweredTypeInfo
Simple,
};
- union
- {
- IRType* type;
- };
+ RefPtr<IRType> type;
Flavor flavor;
LoweredTypeInfo()
@@ -743,6 +748,35 @@ LoweredValInfo lowerDecl(
DeclBase* decl,
Layout* layout);
+IRType* getIntType(
+ IRGenContext* context)
+{
+ return context->getSession()->getBuiltinType(BaseType::Int);
+}
+
+// Get a pointer type to the given element type
+RefPtr<PtrType> getPtrType(
+ IRGenContext* context,
+ IRType* valueType)
+{
+ return context->getSession()->getPtrType(valueType);
+}
+
+RefPtr<IRFuncType> getFuncType(
+ IRGenContext* context,
+ UInt paramCount,
+ RefPtr<IRType> const* paramTypes,
+ IRType* resultType)
+{
+ RefPtr<FuncType> funcType = new FuncType();
+ funcType->resultType = resultType;
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ funcType->paramTypes.Add(paramTypes[pp]);
+ }
+ return funcType;
+}
+
//
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo>
@@ -761,7 +795,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// TODO: it is a bit messy here that the `ConstantIntVal` representation
// has no notion of a *type* associated with the value...
- auto type = getBuilder()->getBaseType(BaseType::Int);
+ auto type = getIntType(context);
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value));
}
@@ -772,16 +806,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredTypeInfo visitFuncType(FuncType* type)
{
- LoweredValInfo loweredFunc = ensureDecl(context, type->declRef);
- auto loweredFuncVal = getSimpleVal(context, loweredFunc);
-
- // HACK: deal with the case where the decl might not
- // lower to anything, and so we don't have a type to
- // work with.
- if (!loweredFuncVal)
- return LoweredTypeInfo();
-
- return loweredFuncVal->getType();
+ return LoweredTypeInfo(type);
}
void addGenericArgs(List<IRValue*>* ioArgs, DeclRefBase declRef)
@@ -799,12 +824,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredTypeInfo visitDeclRefType(DeclRefType* type)
{
+#if 1
+ // TODO: is there actually anything to be done at this point?
+ return LoweredTypeInfo(type);
+#else
// We need to detect builtin/intrinsic types here, since they should map to custom modifiers
// We need to catch builtin/intrinsic types here
if( auto intrinsicTypeMod = type->declRef.getDecl()->FindModifier<IntrinsicTypeModifier>() )
{
auto builder = getBuilder();
- auto intType = builder->getBaseType(BaseType::Int);
+ auto intType = getIntType(context);
//
List<IRValue*> irArgs;
for( auto val : intrinsicTypeMod->irOperands )
@@ -831,61 +860,32 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
default:
SLANG_UNIMPLEMENTED_X("type lowering");
}
-
+#endif
}
LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type)
{
- return getBuilder()->getBaseType(type->baseType);
+ return LoweredTypeInfo(type);
}
LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type)
{
- auto irElementType = lowerSimpleType(context, type->elementType);
- auto irElementCount = lowerSimpleVal(context, type->elementCount);
-
- return getBuilder()->getVectorType(irElementType, irElementCount);
+ return LoweredTypeInfo(type);
}
LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type)
{
- auto irElementType = lowerSimpleType(context, type->getElementType());
- auto irRowCount = lowerSimpleVal(context, type->getRowCount());
- auto irColumnCount = lowerSimpleVal(context, type->getColumnCount());
-
- return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount);
+ return LoweredTypeInfo(type);
}
- LoweredTypeInfo getArrayType(
- LoweredTypeInfo const& loweredElementType,
- IRValue* irElementCount)
+ LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type)
{
- switch (loweredElementType.flavor)
- {
- case LoweredTypeInfo::Flavor::Simple:
- return getBuilder()->getArrayType(
- loweredElementType.type,
- irElementCount);
- break;
-
- default:
- SLANG_UNEXPECTED("array element type");
- break;
- }
+ return LoweredTypeInfo(type);
}
- LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type)
+ LoweredTypeInfo visitIRBasicBlockType(IRBasicBlockType* type)
{
- auto loweredElementType = lowerType(context, type->baseType);
- if (auto elementCount = type->ArrayLength)
- {
- auto irElementCount = lowerSimpleVal(context, elementCount);
- return getArrayType(loweredElementType, irElementCount);
- }
- else
- {
- return getArrayType(loweredElementType, nullptr);
- }
+ return LoweredTypeInfo(type);
}
};
@@ -1017,9 +1017,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
if (auto fieldDeclRef = declRef.As<StructField>())
{
// Okay, easy enough: we have a reference to a field of a struct type...
-
- auto loweredField = ensureDecl(context, fieldDeclRef);
- return extractField(loweredType, loweredBase, loweredField);
+ return extractField(loweredType, loweredBase, fieldDeclRef);
}
else if (auto callableDeclRef = declRef.As<CallableDecl>())
{
@@ -1045,14 +1043,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// in order for a dereference to make senese, so we just
// need to extract the value type from that pointer here.
//
- auto loweredBaseVal = getSimpleVal(context, loweredBase);
- auto loweredBaseType = loweredBaseVal->getType();
- switch( loweredBaseType->op )
- {
- case kIROp_PtrType:
- // TODO: should we enumerate these explicitly?
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
+ IRValue* loweredBaseVal = getSimpleVal(context, loweredBase);
+ RefPtr<Type> loweredBaseType = loweredBaseVal->getType();
+
+ if (loweredBaseType->As<PointerLikeType>()
+ || loweredBaseType->As<PtrType>())
+ {
// Note that we do *not* perform an actual `load` operation
// here, but rather just use the pointer value to construct
// an appropriate `LoweredValInfo` representing the underlying
@@ -1064,8 +1060,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// and is just a bit of pointer math.
//
return LoweredValInfo::ptr(loweredBaseVal);
-
- default:
+ }
+ else
+ {
SLANG_UNIMPLEMENTED_X("codegen for deref expression");
return LoweredValInfo();
}
@@ -1472,7 +1469,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
case LoweredValInfo::Flavor::Ptr:
return LoweredValInfo::ptr(
builder->emitElementAddress(
- builder->getPtrType(getSimpleType(type)),
+ getPtrType(context, getSimpleType(type)),
baseVal.val,
indexVal));
@@ -1484,9 +1481,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
LoweredValInfo extractField(
- LoweredTypeInfo fieldType,
- LoweredValInfo base,
- LoweredValInfo field)
+ LoweredTypeInfo fieldType,
+ LoweredValInfo base,
+ DeclRef<StructField> field)
{
switch (base.flavor)
{
@@ -1497,7 +1494,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
getBuilder()->emitFieldExtract(
getSimpleType(fieldType),
irBase,
- (IRStructField*) getSimpleVal(context, field)));
+ getBuilder()->getDeclRefVal(field)));
}
break;
@@ -1509,9 +1506,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRValue* irBasePtr = base.val;
return LoweredValInfo::ptr(
getBuilder()->emitFieldAddress(
- getBuilder()->getPtrType(getSimpleType(fieldType)),
+ getPtrType(context, getSimpleType(fieldType)),
irBasePtr,
- (IRStructField*) getSimpleVal(context, field)));
+ getBuilder()->getDeclRefVal(field)));
}
break;
}
@@ -1598,7 +1595,7 @@ struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVis
auto builder = getBuilder();
- auto irIntType = builder->getBaseType(BaseType::Int);
+ auto irIntType = getIntType(context);
UInt elementCount = (UInt)expr->elementCount;
IRValue* irElementIndices[4];
@@ -1661,34 +1658,22 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
void insertBlock(IRBlock* block)
{
auto builder = getBuilder();
- auto parent = builder->parentInst;
- IRBlock* prevBlock = nullptr;
- IRFunc* parentFunc = nullptr;
-
- switch (parent->op)
- {
- case kIROp_Block:
- prevBlock = (IRBlock*)parent;
- parentFunc = prevBlock->getParent();
- break;
-
- default:
- SLANG_UNEXPECTED("bad parent kind for block");
- return;
- }
+ auto prevBlock = builder->block;
+ auto parentFunc = prevBlock->parentFunc;
// If the previous block doesn't already have
// a terminator instruction, then be sure to
// emit a branch to the new block.
- if (!isTerminatorInst(prevBlock->lastChild))
+ if (!isTerminatorInst(prevBlock->lastInst))
{
builder->emitBranch(block);
}
- builder->parentInst = parentFunc;
- builder->addInst(block);
- builder->parentInst = block;
+ parentFunc->addBlock(block);
+
+ builder->func = parentFunc;
+ builder->block = block;
}
// Start a new block at the current location.
@@ -2025,8 +2010,8 @@ top:
auto loweredBase = swizzleInfo->base;
// Load from the base value:
- IRInst* irLeftVal = getSimpleVal(context, loweredBase);
- auto irRightVal = getSimpleVal(context, right);
+ IRValue* irLeftVal = getSimpleVal(context, loweredBase);
+ IRValue* irRightVal = getSimpleVal(context, right);
// Now apply the swizzle
IRInst* irSwizzled = builder->emitSwizzleSet(
@@ -2066,7 +2051,7 @@ top:
emitCallToDeclRef(
context,
- builder->getVoidType(),
+ context->getSession()->getVoidType(),
setterDeclRef,
allArgs,
subscriptInfo->genericArgCount);
@@ -2153,8 +2138,67 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo();
}
+ bool isGlobalVarDecl(VarDeclBase* decl)
+ {
+ auto parent = decl->ParentDecl;
+ if (dynamic_cast<ModuleDecl*>(parent))
+ {
+ // Variable declared at global scope? -> Global.
+ return true;
+ }
+
+ return false;
+ }
+
+ LoweredValInfo lowerGlobalVarDecl(VarDeclBase* decl)
+ {
+ auto varType = lowerSimpleType(context, decl->getType());
+
+ IRAddressSpace addressSpace = kIRAddressSpace_Default;
+ if (decl->HasModifier<HLSLGroupSharedModifier>())
+ {
+ addressSpace = kIRAddressSpace_GroupShared;
+ }
+
+ auto builder = getBuilder();
+ auto irGlobal = builder->createGlobalVar(varType);
+
+ if (decl)
+ {
+ builder->addHighLevelDeclDecoration(irGlobal, decl);
+ }
+
+ if (auto layout = getLayout())
+ {
+ builder->addLayoutDecoration(irGlobal, layout);
+ }
+
+ // A global variable's SSA value is a *pointer* to
+ // the underlying storage.
+ auto globalVal = LoweredValInfo::ptr(irGlobal);
+ context->shared->declValues.Add(
+ DeclRef<VarDeclBase>(decl, nullptr),
+ globalVal);
+
+ if( auto initExpr = decl->initExpr )
+ {
+ // TODO: need to handle global with initializer!
+ }
+
+ getBuilder()->getModule()->globalValues.Add(irGlobal);
+
+ return globalVal;
+ }
+
LoweredValInfo visitVarDeclBase(VarDeclBase* decl)
{
+ // Detect global (or effectively global) variables
+ // and handle them differently.
+ if (isGlobalVarDecl(decl))
+ {
+ return lowerGlobalVarDecl(decl);
+ }
+
// A user-defined variable declaration will usually turn into
// an `alloca` operation for the variable's storage,
// plus some code to initialize it and then store to the variable.
@@ -2219,6 +2263,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl)
{
+#if 0
// User-defined aggregate type: need to translate into
// a corresponding IR aggregate type.
@@ -2257,6 +2302,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->addInst(irStruct);
return LoweredValInfo::simple(irStruct);
+#else
+ // TODO: What is there to do with a `struct` type?
+ return LoweredValInfo();
+#endif
}
// Sometimes we need to refer to a declaration the way that it would be specialized
@@ -2592,7 +2641,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
void trySetMangledName(
- IRInst* inst,
+ IRFunc* irFunc,
Decl* decl)
{
// We want to generate a mangled name for the given declaration and attach
@@ -2605,8 +2654,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
String mangledName = getMangledName(decl);
- auto decoration = getBuilder()->addDecoration<IRMangledNameDecoration>(inst);
- decoration->mangledName = mangledName;
+ irFunc->mangledName = mangledName;
}
@@ -2649,11 +2697,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// need to create an IR function here
IRFunc* irFunc = subBuilder->createFunc();
- subBuilder->parentInst = irFunc;
+ subBuilder->func = irFunc;
trySetMangledName(irFunc, decl);
- List<IRType*> paramTypes;
+ List<RefPtr<Type>> paramTypes;
// We first need to walk the generic parameters (if any)
// because these will influence the declared type of
@@ -2662,6 +2710,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
for( auto genericParamDecl : parameterLists.genericParams )
{
UInt genericParamIndex = genericParamCounter++;
+#if 0
if( auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(genericParamDecl) )
{
// In the logical type for the function, a generic
@@ -2675,10 +2724,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// to the appropriate generic parameter position.
IRType* irParameterType = context->irBuilder->getGenericParameterType(genericParamIndex);
- LoweredValInfo LoweredValInfo = LoweredValInfo::simple(irParameterType);
+ LoweredValInfo LoweredValInfo = LoweredValInfo::type(irParameterType);
subContext->shared->declValues[makeDeclRef(genericTypeParamDecl)] = LoweredValInfo;
}
else
+#endif
{
// TODO: handle the other cases here.
SLANG_UNEXPECTED("generic parameter kind");
@@ -2702,7 +2752,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = subBuilder->getPtrType(irParamType);
+ auto irPtrType = getPtrType(context, irParamType);
paramTypes.Add(irPtrType);
}
}
@@ -2726,14 +2776,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Instead, a setter always returns `void`
//
- irResultType = getBuilder()->getVoidType();
+ irResultType = context->getSession()->getVoidType();
}
- auto irFuncType = getBuilder()->getFuncType(
+ auto irFuncType = getFuncType(
+ context,
paramTypes.Count(),
paramTypes.Buffer(),
irResultType);
- irFunc->type.init(irFunc, irFuncType);
+ irFunc->type = irFuncType;
if (!decl->Body)
{
@@ -2753,7 +2804,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// This is a function definition, so we need to actually
// construct IR for the body...
IRBlock* entryBlock = subBuilder->emitBlock();
- subBuilder->parentInst = entryBlock;
+ subBuilder->block = entryBlock;
UInt paramTypeIndex = 0;
for( auto paramInfo : parameterLists.params )
@@ -2771,7 +2822,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = (IRPtrType*)irParamType;
+ auto irPtrType = irParamType.As<PtrType>();
IRParam* irParamPtr = subBuilder->emitParam(irPtrType);
if(auto paramDecl = paramInfo.decl)
@@ -2829,14 +2880,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// We need to carefully add a terminator instruction to the end
// of the body, in case the user didn't do so.
- if (!isTerminatorInst(subContext->irBuilder->parentInst->lastChild))
+ if (!isTerminatorInst(subContext->irBuilder->block->lastInst))
{
- if (irResultType->op == kIROp_VoidType)
+ if (irResultType->Equals(context->getSession()->getVoidType()))
{
+ // `void`-returning function can get an implicit
+ // return on exit of the body statement.
subContext->irBuilder->emitReturn();
}
else
{
+ // Value-returning function is expected to `return`
+ // on every control-flow path. We need to enforce
+ // this by putting an `unreachable` terminator here,
+ // and then emit a dataflow error if this block
+ // can't be eliminated.
SLANG_UNEXPECTED("Needed a return here");
subContext->irBuilder->emitReturn();
}
@@ -2845,7 +2903,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addHighLevelDeclDecoration(irFunc, decl);
- getBuilder()->addInst(irFunc);
+ getBuilder()->getModule()->globalValues.Add(irFunc);
return LoweredValInfo::simple(irFunc);
}
@@ -2878,7 +2936,6 @@ LoweredValInfo ensureDecl(
IRBuilder subIRBuilder;
subIRBuilder.shared = context->irBuilder->shared;
- subIRBuilder.parentInst = subIRBuilder.shared->module;
IRGenContext subContext = *context;
@@ -2997,15 +3054,14 @@ IRModule* lowerEntryPointToIR(
SharedIRBuilder sharedBuilderStorage;
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->module = nullptr;
+ sharedBuilder->session = entryPoint->compileRequest->mSession;
IRBuilder builderStorage;
IRBuilder* builder = &builderStorage;
builder->shared = sharedBuilder;
- builder->parentInst = nullptr;
IRModule* module = builder->createModule();
sharedBuilder->module = module;
- builder->parentInst = module;
context->irBuilder = builder;
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index 85f19f14c..6d848062c 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -708,6 +708,12 @@ struct LoweringVisitor
return result;
}
+ RefPtr<Type> visitIRBasicBlockType(IRBasicBlockType* type)
+ {
+ return type;
+ }
+
+
RefPtr<Type> visitErrorType(ErrorType* type)
{
return type;
@@ -732,9 +738,16 @@ struct LoweringVisitor
RefPtr<Type> visitFuncType(FuncType* type)
{
- RefPtr<FuncType> loweredType = getFuncType(
- getSession(),
- translateDeclRef(DeclRef<Decl>(type->declRef)).As<CallableDecl>());
+ RefPtr<FuncType> loweredType = new FuncType();
+ loweredType->resultType = lowerType(type->resultType);
+ for (auto paramType : type->paramTypes)
+ {
+ auto loweredParamType = lowerType(paramType);
+
+ // TODO: it seems like this step needs to scalarize
+ // in the case where a parameter type is a tuple...
+ loweredType->paramTypes.Add(loweredParamType);
+ }
return loweredType;
}
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 475802cb1..4dbb6c05b 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -216,6 +216,9 @@ void Type::accept(IValVisitor* visitor, void* extra)
overloadedType = new OverloadGroupType();
overloadedType->setSession(this);
+
+ irBasicBlockType = new IRBasicBlockType();
+ irBasicBlockType->setSession(this);
}
Type* Session::getBoolType()
@@ -268,6 +271,29 @@ void Type::accept(IValVisitor* visitor, void* extra)
return errorType;
}
+ Type* Session::getIRBasicBlockType()
+ {
+ return irBasicBlockType;
+ }
+
+ RefPtr<PtrType> Session::getPtrType(
+ RefPtr<Type> valueType)
+ {
+ auto genericDecl = findMagicDecl(
+ this, "PtrType").As<GenericDecl>();
+ auto typeDecl = genericDecl->inner;
+
+ auto substitutions = new Substitutions();
+ substitutions->genericDecl = genericDecl.Ptr();
+ substitutions->args.Add(valueType);
+
+ auto declRef = DeclRef<Decl>(typeDecl.Ptr(), substitutions);
+
+ return DeclRefType::Create(
+ this,
+ declRef)->As<PtrType>();
+ }
+
SyntaxClass<RefObject> Session::findSyntaxClass(Name* name)
{
SyntaxClass<RefObject> syntaxClass;
@@ -545,7 +571,26 @@ void Type::accept(IValVisitor* visitor, void* extra)
else
{
- SLANG_UNEXPECTED("unhandled type");
+ auto classInfo = session->findSyntaxClass(
+ session->getNamePool()->getName(magicMod->name));
+ if (!classInfo.classInfo)
+ {
+ SLANG_UNEXPECTED("unhandled type");
+ }
+
+ auto type = classInfo.createInstance();
+ if (!type)
+ {
+ SLANG_UNEXPECTED("constructor failure");
+ }
+
+ auto declRefType = dynamic_cast<DeclRefType*>(type);
+ if (!declRefType)
+ {
+ SLANG_UNEXPECTED("expected a declaration reference type");
+ }
+ declRefType->declRef = declRef;
+ return declRefType;
}
}
else
@@ -578,6 +623,28 @@ void Type::accept(IValVisitor* visitor, void* extra)
return (int)(int64_t)(void*)this;
}
+ // IRBasicBlockType
+
+ String IRBasicBlockType::ToString()
+ {
+ return "Block";
+ }
+
+ bool IRBasicBlockType::EqualsImpl(Type * /*type*/)
+ {
+ return false;
+ }
+
+ Type* IRBasicBlockType::CreateCanonicalType()
+ {
+ return this;
+ }
+
+ int IRBasicBlockType::GetHashCode()
+ {
+ return (int)(int64_t)(void*)this;
+ }
+
// InitializerListType
String InitializerListType::ToString()
@@ -653,18 +720,43 @@ void Type::accept(IValVisitor* visitor, void* extra)
String FuncType::ToString()
{
- // TODO: a better approach than this
- if (declRef)
- return getText(declRef.GetName());
- else
- return "/* unknown FuncType */";
+ StringBuilder sb;
+ sb << "(";
+ UInt paramCount = getParamCount();
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ if (pp != 0) sb << ", ";
+ sb << getParamType(pp)->ToString();
+ }
+ sb << ") -> ";
+ sb << getResultType()->ToString();
+ return sb.ProduceString();
}
bool FuncType::EqualsImpl(Type * type)
{
if (auto funcType = type->As<FuncType>())
{
- return declRef == funcType->declRef;
+ auto paramCount = getParamCount();
+ auto otherParamCount = funcType->getParamCount();
+ if (paramCount != otherParamCount)
+ return false;
+
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ auto paramType = getParamType(pp);
+ auto otherParamType = funcType->getParamType(pp);
+ if (!paramType->Equals(otherParamType))
+ return false;
+ }
+
+ if(!resultType->Equals(funcType->resultType))
+ return false;
+
+ // TODO: if we ever introduce other kinds
+ // of qualification on function types, we'd
+ // want to consider it here.
+ return true;
}
return false;
}
@@ -676,7 +768,16 @@ void Type::accept(IValVisitor* visitor, void* extra)
int FuncType::GetHashCode()
{
- return declRef.GetHashCode();
+ int hashCode = getResultType()->GetHashCode();
+ UInt paramCount = getParamCount();
+ hashCode = combineHash(hashCode, Slang::GetHashCode(paramCount));
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ hashCode = combineHash(
+ hashCode,
+ getParamType(pp)->GetHashCode());
+ }
+ return hashCode;
}
// TypeType
@@ -782,6 +883,13 @@ void Type::accept(IValVisitor* visitor, void* extra)
return this->declRef.substitutions->args[2].As<IntVal>().Ptr();
}
+ // PtrTypeBase
+
+ Type* PtrTypeBase::getValueType()
+ {
+ return this->declRef.substitutions->args[0].As<Type>().Ptr();
+ }
+
// GenericParamIntVal
bool GenericParamIntVal::EqualsVal(Val* val)
@@ -1161,7 +1269,13 @@ void Type::accept(IValVisitor* visitor, void* extra)
{
auto funcType = new FuncType();
funcType->setSession(session);
- funcType->declRef = declRef;
+
+ funcType->resultType = GetResultType(declRef);
+ for (auto pp : GetParameters(declRef))
+ {
+ funcType->paramTypes.Add(GetType(pp));
+ }
+
return funcType;
}
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 8d6a1edbc..f5e22d9b5 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -74,6 +74,29 @@ namespace Slang
kConversionCost_ScalarToVector = 1,
};
+ // TODO(tfoley): We should ditch this enumeration
+ // and just use the IR opcodes that represent these
+ // types directly. The one major complication there
+ // is that the order of the enum values currently
+ // matters, since it determines promotion rank.
+ // We either need to keep that restriction, or
+ // look up promotion rank by some other means.
+ //
+ enum class BaseType
+ {
+ // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this
+
+ Void = 0,
+ Bool,
+ Int,
+ UInt,
+ UInt64,
+ Half,
+ Float,
+ Double,
+ };
+
+
// Forward-declare all syntax classes
#define SYNTAX_CLASS(NAME, BASE, ...) class NAME;
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index 7883b5c42..7fbefc368 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -41,6 +41,20 @@ protected:
)
END_SYNTAX_CLASS()
+// The type of a reference to a basic block
+// in our IR
+SYNTAX_CLASS(IRBasicBlockType, Type)
+RAW(
+public:
+ virtual String ToString() override;
+
+protected:
+ virtual bool EqualsImpl(Type * type) override;
+ virtual Type* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+)
+END_SYNTAX_CLASS()
+
// A type that takes the form of a reference to some declaration
SYNTAX_CLASS(DeclRefType, Type)
DECL_FIELD(DeclRef<Decl>, declRef)
@@ -221,6 +235,8 @@ END_SYNTAX_CLASS()
// Other cases of generic types known to the compiler
SYNTAX_CLASS(BuiltinGenericType, DeclRefType)
SYNTAX_FIELD(RefPtr<Type>, elementType)
+
+ RAW(Type* getElementType() { return elementType; })
END_SYNTAX_CLASS()
// Types that behave like pointers, in that they can be
@@ -353,6 +369,33 @@ protected:
)
END_SYNTAX_CLASS()
+// Base class for types that map down to
+// simple pointers as part of code generation.
+SYNTAX_CLASS(PtrTypeBase, DeclRefType)
+RAW(
+ // Get the type of the pointed-to value.
+ Type* getValueType();
+)
+END_SYNTAX_CLASS()
+
+// A true (user-visible) pointer type, e.g., `T*`
+SYNTAX_CLASS(PtrType, PtrTypeBase)
+END_SYNTAX_CLASS()
+
+// A type that represents the behind-the-scenes
+// logical pointer that is passed for an `out`
+// or `in out` parameter
+SYNTAX_CLASS(OutTypeBase, PtrTypeBase)
+END_SYNTAX_CLASS()
+
+// The type for an `out` parameter, e.g., `out T`
+SYNTAX_CLASS(OutType, OutTypeBase)
+END_SYNTAX_CLASS()
+
+// The type for an `in out` parameter, e.g., `in out T`
+SYNTAX_CLASS(InOutType, OutTypeBase)
+END_SYNTAX_CLASS()
+
// A type alias of some kind (e.g., via `typedef`)
SYNTAX_CLASS(NamedExpressionType, Type)
DECL_FIELD(DeclRef<TypeDefDecl>, declRef)
@@ -375,17 +418,27 @@ protected:
)
END_SYNTAX_CLASS()
-// Function types are currently used for references to symbols that name
-// either ordinary functions, or "component functions."
-// We do not directly store a representation of the type, and instead
-// use a reference to the symbol to stand in for its logical type
+// A function type is defined by its parameter types
+// and its result type.
SYNTAX_CLASS(FuncType, Type)
- DECL_FIELD(DeclRef<CallableDecl>, declRef)
+
+ // TODO: We may want to preserve parameter names
+ // in the list here, just so that we can print
+ // out friendly names when printing a function
+ // type, even if they don't affect the actual
+ // semantic type underneath.
+
+ FIELD(List<RefPtr<Type>>, paramTypes)
+ FIELD(RefPtr<Type>, resultType)
RAW(
FuncType()
{}
+ UInt getParamCount() { return paramTypes.Count(); }
+ Type* getParamType(UInt index) { return paramTypes[index]; }
+ Type* getResultType() { return resultType; }
+
virtual String ToString() override;
protected:
virtual bool EqualsImpl(Type * type) override;
diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp
index 65ecefc01..d3e44a947 100644
--- a/source/slang/vm.cpp
+++ b/source/slang/vm.cpp
@@ -11,28 +11,43 @@
namespace Slang
{
-// Representation of a type during VM execution
-struct VMTypeImpl
+struct VMValImpl
{
+ // opcode used to construct the value
uint32_t op;
+};
+
+// Representation of a type during VM execution
+struct VMTypeImpl : VMValImpl
+{
+ // number of arguments to the type
+ uint32_t argCount;
// Size and alignment of instances of this
// type.
UInt size;
UInt alignment;
+
+ // operands follow
};
-struct VMType
+
+struct VMVal
{
- VMTypeImpl* impl;
+ VMValImpl* impl;
- UInt getSize() { return impl->size; }
- UInt getAlignment() { return impl->alignment; }
+ VMValImpl* getImpl() { return impl; }
+};
+
+struct VMType : VMVal
+{
+ VMTypeImpl* getImpl() { return (VMTypeImpl*) impl; }
+ UInt getSize() { return getImpl()->size; }
+ UInt getAlignment() { return getImpl()->alignment; }
};
struct VMPtrTypeImpl : VMTypeImpl
{
VMType base;
- uint32_t addressSpace;
};
struct VMReg
@@ -47,20 +62,20 @@ struct VMReg
struct VMConst
{
// Type of the constant
- VMType type;
+ VMType type;
// Operand address to use
void* ptr;
};
-struct VMFrame;
+struct VMModule;
// Information about a function after it has been
// loaded into the VM.
struct VMFunc
{
// The parent module that this function belongs to
- VMFrame* module;
+ VMModule* module;
BCFunc* bcFunc;
VMReg* regs;
@@ -84,8 +99,14 @@ struct VMFrame
// Registers are stored after this point.
};
-struct VMModule : VMFunc
+struct VM;
+
+struct VMModule
{
+ BCModule* bcModule;
+ VM* vm;
+ void** symbols;
+ VMType* types;
};
UInt decodeUInt(BCOp** ioPtr)
@@ -93,13 +114,28 @@ UInt decodeUInt(BCOp** ioPtr)
BCOp* ptr = *ioPtr;
UInt value = *ptr++;
- if( value >= 128 )
+ if( value < 128 )
{
- SLANG_UNEXPECTED("deal with this later");
+ *ioPtr = ptr;
+ return value;
}
- *ioPtr = ptr;
- return value;
+ // Slower path for variable-length encoding
+
+ UInt result = 0;
+ for(;;)
+ {
+ value = value & 0x7F;
+ result = (result << 7) | value;
+
+ if(value < 127)
+ {
+ *ioPtr = ptr;
+ return value;
+ }
+
+ value = *ptr++;
+ }
}
Int decodeSInt(BCOp** ioPtr)
@@ -194,9 +230,15 @@ T& decodeOperand(VMFrame* frame, BCOp** ioIP)
return *decodeOperandPtr<T>(frame, ioIP);
}
+VMType decodeType(VMFrame* frame, BCOp** ioIP)
+{
+ UInt id = decodeUInt(ioIP);
+ return frame->func->module->types[id];
+}
+
VMFunc* loadVMFunc(
BCFunc* bcFunc,
- VMFrame* vmModuleInstance);
+ VMModule* vmModule);
struct VMSizeAlign
{
@@ -231,9 +273,30 @@ VMSizeAlign getVMSymbolSize(BCSymbol* symbol)
return result;
}
+VMType getType(
+ VMModule* vmModule,
+ uint32_t typeID)
+{
+ return vmModule->types[typeID];
+}
+
+void* getGlobalPtr(
+ VMModule* vmModule,
+ uint32_t globalID)
+{
+ return vmModule->symbols[globalID];
+}
+
+VMType getGlobalType(
+ VMModule* vmModule,
+ uint32_t globalID)
+{
+ return getType(vmModule, vmModule->bcModule->symbols[globalID]->typeID);
+}
+
VMFunc* loadVMFunc(
BCFunc* bcFunc,
- VMFrame* vmModuleInstance)
+ VMModule* vmModule)
{
UInt regCount = bcFunc->regCount;
UInt constCount = bcFunc->constCount;
@@ -245,7 +308,7 @@ VMFunc* loadVMFunc(
VMReg* vmRegs = (VMReg*) (vmFunc + 1);
VMConst* vmConsts = (VMConst*) (vmRegs + regCount);
- vmFunc->module = vmModuleInstance;
+ vmFunc->module = vmModule;
vmFunc->bcFunc = bcFunc;
vmFunc->regs = vmRegs;
vmFunc->consts = vmConsts;
@@ -254,31 +317,14 @@ VMFunc* loadVMFunc(
for( UInt rr = 0; rr < regCount; ++rr )
{
BCReg* bcReg = &bcFunc->regs[rr];
- auto typeGlobalID = bcReg->typeGlobalID;
-
- // HACK: when we are loading a module itself, we might
- // not yet know the size for the things it defines
- // (since the module itself might define the type of
- // one of its symbols), so for now we hack it and
- // assume everything at module level is 16 bytes or less.
- //
- // TODO: this also seems like it will cause problems
- // in other contexts (any time the type of a register
- // would depend on an earlier instruction in the same
- // scope) so this needs careful thought.
- VMType vmType = { nullptr };
- UInt regSize = 16;
- UInt regAlign = 8;
-
- if (vmModuleInstance)
- {
- // We expect the type to come from the outer module, so
- // that we can allocate space for it as we go.
- vmType = *(VMType*)getRegPtrImpl(vmModuleInstance, typeGlobalID);
+ auto bcTypeID = bcReg->typeID;
- regSize = vmType.getSize();
- regAlign = vmType.getAlignment();
- }
+ // We expect the type to come from the outer module, so
+ // that we can allocate space for it as we go.
+ auto vmType = getType(vmModule, bcTypeID);
+
+ auto regSize = vmType.getSize();
+ auto regAlign = vmType.getAlignment();
offset = (offset + (regAlign-1)) & ~(regAlign-1);
@@ -292,10 +338,32 @@ VMFunc* loadVMFunc(
for( UInt cc = 0; cc < constCount; ++cc )
{
- auto globalID = bcFunc->consts[cc].globalID;
- auto globalPtr = getRegPtrImpl(vmModuleInstance, globalID);
- vmFunc->consts[cc].ptr = globalPtr;
- vmFunc->consts[cc].type = vmModuleInstance->func->regs[globalID].type;
+ BCConst bcConst = bcFunc->consts[cc];
+ switch( bcConst.flavor )
+ {
+ case kBCConstFlavor_GlobalSymbol:
+ {
+ auto globalID = bcConst.id;
+ vmFunc->consts[cc].ptr = &vmModule->symbols[globalID];
+ vmFunc->consts[cc].type = getGlobalType(vmModule, globalID);
+ }
+ break;
+
+ case kBCConstFlavor_Constant:
+ {
+ auto constID = bcConst.id;
+ auto constInfo = &vmModule->bcModule->constants[constID];
+ vmFunc->consts[cc].ptr = constInfo->ptr;
+ #if 0
+ fprintf(stderr, "CONSANT[%d] : [%p]\n", (int)cc, vmFunc->consts[cc].ptr);
+ fprintf(stderr, "BC [%p] : %d\n", &constInfo->ptr, (int)constInfo->ptr.rawVal);
+ #endif
+ vmFunc->consts[cc].type = getType(vmModule, constInfo->typeID);
+ }
+ break;
+ }
+
+
}
return vmFunc;
@@ -368,7 +436,7 @@ void dumpVMFrame(VMFrame* vmFrame)
case kIROp_PtrType:
{
- fprintf(stderr, ": Ptr<???> = 0x%p", *(void**)regData);
+ fprintf(stderr, ": Ptr<?> = [%p]", *(void**)regData);
}
break;
@@ -416,7 +484,186 @@ struct VMThread
void resumeThread(
VMThread* vmThread);
-VMFrame* loadVMModuleInstance(
+void computeTypeSizeAlign(
+ VMTypeImpl* impl)
+{
+ UInt size = 0;
+ UInt alignment = 0;
+ switch(impl->op)
+ {
+ case kIROp_VoidType:
+ size = 0;
+ break;
+
+ case kIROp_BoolType:
+ size = 1;
+ break;
+
+ case kIROp_Int32Type:
+ case kIROp_UInt32Type:
+ case kIROp_Float32Type:
+ size = 4;
+ break;
+
+ case kIROp_FuncType:
+ case kIROp_PtrType:
+ case kIROp_readWriteStructuredBufferType:
+ case kIROp_structuredBufferType:
+ size = sizeof(void*);
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("type sizing");
+ impl->size = 0;
+ break;
+ }
+
+ if(!alignment)
+ alignment = size;
+ if(!alignment)
+ alignment = 1;
+
+ impl->size = size;
+ impl->alignment = alignment;
+}
+
+VMType getType(
+ VM* vm,
+ VMTypeImpl* typeImpl)
+{
+ // TODO: need to look up an existing type that matches...
+
+ UInt argCount = typeImpl->argCount;
+ UInt size = sizeof(VMTypeImpl) + argCount*sizeof(VMType);
+
+ VMTypeImpl* impl = (VMTypeImpl*) malloc(size);
+ memcpy(impl, typeImpl, size);
+
+ computeTypeSizeAlign(impl);
+
+ VMType type;
+ type.impl = impl;
+ return type;
+}
+
+VMVal getVal(
+ VMModule* vmModule,
+ UInt index)
+{
+ return vmModule->types[index];
+}
+
+VMType loadVMType(
+ VMModule* vmModule,
+ BCType* bcType)
+{
+ // Need to load type from BC format to VM
+ IROp op = (IROp) bcType->op;
+ switch(bcType->op)
+ {
+ case kIROp_PtrType:
+ {
+ // TODO: need to do some caching!
+ BCPtrType* bcPtrType = (BCPtrType*) bcType;
+
+ VMPtrTypeImpl vmPtrTypeImpl;
+ vmPtrTypeImpl.op = op;
+ vmPtrTypeImpl.argCount = 1;
+ vmPtrTypeImpl.size = sizeof(void*);
+ vmPtrTypeImpl.alignment = sizeof(void*);
+ vmPtrTypeImpl.base = getType(vmModule, bcPtrType->valueType->id);
+
+ auto vmPtrType = getType(vmModule->vm, &vmPtrTypeImpl);
+ return vmPtrType;
+ }
+ break;
+
+ default:
+ {
+ UInt argCount = bcType->argCount;
+
+ UInt size = sizeof(VMTypeImpl) + argCount * sizeof(VMVal);
+
+ VMTypeImpl* impl = (VMTypeImpl*) alloca(size);
+ memset(impl, 0, size);
+ impl->op = bcType->op;
+ impl->argCount = argCount;
+
+ VMVal* args = (VMVal*) (impl + 1);
+ for(UInt aa = 0; aa < argCount; ++aa)
+ {
+ args[aa] = getVal(vmModule, bcType->getArg(aa)->id);
+ }
+
+ return getType(vmModule->vm, impl);
+ }
+
+ SLANG_UNEXPECTED("unimplemented");
+ return VMType();
+ break;
+ }
+}
+
+void* allocateImpl(VM* vm, UInt size, UInt align)
+{
+ void* ptr = malloc(size);
+ memset(ptr, 0, size);
+ return ptr;
+}
+
+void* allocate(VM* vm, VMType type)
+{
+ return allocateImpl(vm, type.getSize(), type.getAlignment());
+}
+
+template<typename T>
+T* allocate(VM* vm)
+{
+ return allocateImpl(vm, sizeof(T), alignof(T));
+}
+
+void* loadVMSymbol(
+ VMModule* vmModule,
+ BCSymbol* bcSymbol)
+{
+ // Need to load type from BC format to VM
+
+ auto vm = vmModule->vm;
+
+ switch(bcSymbol->op)
+ {
+ case kIROp_global_var:
+ {
+ auto type = getType(vmModule, bcSymbol->typeID);
+ assert(type.impl->op == kIROp_PtrType);
+
+ VMPtrTypeImpl* ptrTypeImpl = (VMPtrTypeImpl*) type.impl;
+ auto valueType = ptrTypeImpl->base;
+
+ void* varValue = allocate(vm, valueType);
+ void** varPtr = (void**) allocate(vm, type);
+
+
+ *varPtr = varValue;
+
+ return varPtr;
+ }
+ break;
+
+ case kIROp_Func:
+ {
+ auto bcFunc = (BCFunc*) bcSymbol;
+ VMFunc* vmFunc = loadVMFunc(bcFunc, vmModule);
+ return vmFunc;
+ }
+ break;
+
+ default:
+ return nullptr;
+ }
+}
+
+VMModule* loadVMModuleInstance(
VM* vm,
void const* bytecode,
size_t bytecodeSize)
@@ -425,45 +672,63 @@ VMFrame* loadVMModuleInstance(
BCModule* bcModule = bcHeader->module;
- UInt vmModuleSize = sizeof(VMModule) + bcModule->regCount * sizeof(VMReg);
+ UInt symbolCount = bcModule->symbolCount;
+ UInt typeCount = bcModule->typeCount;
- VMModule* vmModule = (VMModule*) loadVMFunc(bcModule, nullptr);
+ UInt vmModuleSize = sizeof(VMModule)
+ + symbolCount * sizeof(void*)
+ + typeCount * sizeof(VMType);
- // Create a frame to store the loaded symbols, and execute it
- // to initialize them.
- VMFrame* vmModuleInstance = createFrame(vmModule);
- vmModuleInstance->parent = nullptr;
- vmModule->module = vmModuleInstance;
+ VMModule* vmModule = (VMModule*)malloc(vmModuleSize);
+ memset(vmModule, 0, vmModuleSize);
+ void** vmSymbols = (void**)(vmModule + 1);
+ VMType* vmTypes = (VMType*)(vmSymbols + symbolCount);
- VMThread thread;
- thread.frame = vmModuleInstance;
+ vmModule->bcModule = bcModule;
+ vmModule->vm = vm;
+ vmModule->symbols = vmSymbols;
+ vmModule->types = vmTypes;
- resumeThread(&thread);
+ // Initialize types before symbols, since the symbols
+ // will all have types...
+ for(UInt tt = 0; tt < typeCount; ++tt)
+ {
+ BCType* bcType = bcModule->types[tt];
+ vmTypes[tt] = loadVMType(vmModule, bcType);
+ }
+
+ // Now we need to initialize all the VM-level symbols
+ // from their BC-level equivalents.
+ for(UInt ss = 0; ss < symbolCount; ++ss)
+ {
+ BCSymbol* bcSymbol = bcModule->symbols[ss];
+ vmSymbols[ss] = loadVMSymbol(vmModule, bcSymbol);
+ }
- return vmModuleInstance;
+ return vmModule;
}
void* findGlobalSymbolPtr(
- VMFrame* moduleInstance,
+ VMModule* module,
char const* name)
{
// Okay, we need to search through the available
- // "registers" looking for one that gives us a
- // name match.
- //
- BCFunc* bcFunc = moduleInstance->func->bcFunc;
- UInt regCount = bcFunc->regCount;
- for (UInt rr = 0; rr < regCount; ++rr)
+ // symbols, looking for one that gives us a name
+ // match.
+
+ BCModule* bcModule = module->bcModule;
+ UInt symbolCount = bcModule->symbolCount;
+ for(UInt ss = 0; ss < symbolCount; ++ss)
{
- BCReg* bcReg = &bcFunc->regs[rr];
+ BCSymbol* bcSymbol = bcModule->symbols[ss];
- char const* symbolName = bcReg->name;
+ char const* symbolName = bcSymbol->name;
if (!symbolName)
continue;
if(strcmp(symbolName, name) == 0)
- return getRegPtrImpl(moduleInstance, rr);
+ return getGlobalPtr(module, ss);
}
return nullptr;
@@ -516,182 +781,11 @@ void resumeThread(
auto op = (IROp) decodeUInt(&ip);
switch( op )
{
- case kIROp_TypeType:
- {
- // The type of types
- Int argCount = decodeUInt(&ip);
- void* arg0Ptr = decodeOperandPtr<void>(frame, &ip);
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
- auto typeImpl = new VMTypeImpl();
- typeImpl->op = op;
- typeImpl->size = sizeof(VMType);
- typeImpl->alignment = alignof(VMType);
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
- case kIROp_BlockType:
- case kIROp_Int32Type:
- case kIROp_UInt32Type:
- case kIROp_Float32Type:
- case kIROp_BoolType:
- case kIROp_VoidType:
- case kIROp_FuncType: // TODO: we should in principle handle function types here
- {
- // Case to handle types without arguments.
- UInt argCount = decodeUInt(&ip);
- for( UInt aa = 0; aa < argCount; ++aa )
- {
- void* argPtr = decodeOperandPtr<void>(frame, &ip);
- }
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
- auto typeImpl = new VMTypeImpl();
- typeImpl->op = op;
-
- UInt size = 1;
- UInt align = 0;
- switch (op)
- {
- case kIROp_BlockType: size = sizeof(void*); break;
- case kIROp_Int32Type: size = sizeof(int32_t); break;
- case kIROp_UInt32Type: size = sizeof(uint32_t); break;
- case kIROp_Float32Type: size = sizeof(float); break;
- case kIROp_BoolType: size = sizeof(bool); break;
- case kIROp_VoidType: size = 0; align = 1; break;
- default:
- break;
- }
- if (!align) align = size;
- typeImpl->size = size;
- typeImpl->alignment = align;
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
- case kIROp_PtrType:
- {
- // Case to handle types without arguments.
- UInt argCount = decodeUInt(&ip);
- VMType* typeTypePtr = decodeOperandPtr<VMType>(frame, &ip);
- VMType baseType = decodeOperand<VMType>(frame, &ip);
- int32_t addressSpace = decodeOperand<int32_t>(frame, &ip);
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
-
-
- auto typeImpl = new VMPtrTypeImpl();
- typeImpl->op = op;
- typeImpl->size = sizeof(void*);
- typeImpl->alignment = alignof(void*);
- typeImpl->base = baseType;
- typeImpl->addressSpace = addressSpace;
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
- case kIROp_readWriteStructuredBufferType:
- case kIROp_structuredBufferType:
- {
- // Case to handle types without arguments.
- UInt argCount = decodeUInt(&ip);
- VMType* typeTypePtr = decodeOperandPtr<VMType>(frame, &ip);
- VMType baseType = decodeOperand<VMType>(frame, &ip);
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
-
- // TODO: give these their own representations!
- auto typeImpl = new VMPtrTypeImpl();
- typeImpl->op = op;
- typeImpl->base = baseType;
- typeImpl->size = sizeof(void*);
- typeImpl->alignment = sizeof(void*);
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
-
- case kIROp_IntLit:
- {
- VMType type = decodeOperand<VMType>(frame, &ip);
- UInt uVal = decodeUInt(&ip);
- void* destPtr = decodeOperandPtr<void>(frame, &ip);
-
- switch( type.impl->op )
- {
- case kIROp_Int32Type:
- *(int32_t*)destPtr = int32_t(uVal);
- break;
-
- case kIROp_UInt32Type:
- *(uint32_t*)destPtr = uint32_t(uVal);
- break;
-
- default:
- SLANG_UNEXPECTED("integer type");
- break;
- }
-
- }
- break;
-
- case kIROp_FloatLit:
- {
- VMType type = decodeOperand<VMType>(frame, &ip);
-
- static const UInt size = sizeof(IRFloatingPointValue);
- IRFloatingPointValue value;
- memcpy(&value, ip, size);
- ip += size;
- void* destPtr = decodeOperandPtr<void>(frame, &ip);
-
- switch( type.impl->op )
- {
- case kIROp_Float32Type:
- *(float*)destPtr = float(value);
- break;
-
- default:
- SLANG_UNEXPECTED("float type");
- break;
- }
- }
- break;
-
- case kIROp_boolConst:
- {
- bool val = (*ip++) != 0;
- bool* destPtr = decodeOperandPtr<bool>(frame, &ip);
- *destPtr = val;
- }
- break;
-
- case kIROp_Func:
- {
- UInt nestedID = decodeUInt(&ip);
- void* destPtr = decodeOperandPtr<void>(frame, &ip);
-
- BCSymbol* bcSymbol = frame->func->bcFunc->nestedSymbols[nestedID];
- BCFunc* bcFunc = (BCFunc*)bcSymbol;
- VMFunc* vmFunc = loadVMFunc(bcFunc, frame->func->module);
-
- *(VMFunc**)destPtr = vmFunc;
- }
- break;
-
case kIROp_Var:
{
// This instruction represents the `alloca` for a variable of some type.
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
for( UInt aa = 0; aa < argCount; ++aa )
@@ -714,10 +808,17 @@ void resumeThread(
case kIROp_Store:
{
// An ordinary memory store
- VMType type = decodeOperand<VMType>(frame, &ip);
+ VMType type = decodeType(frame, &ip);
void* dest = decodeOperand<void*>(frame, &ip);
void* src = decodeOperandPtr<void>(frame, &ip);
+#if 0
+ fprintf(stderr, "STORE *[%p] = [%p] // size: %d\n",
+ dest,
+ src,
+ (int) type.getSize());
+#endif
+
memcpy(dest, src, type.getSize());
}
break;
@@ -725,7 +826,7 @@ void resumeThread(
case kIROp_Load:
{
// An ordinary memory store
- VMType type = decodeOperand<VMType>(frame, &ip);
+ VMType type = decodeType(frame, &ip);
void* src = decodeOperand<void*>(frame, &ip);
void* dest = decodeOperandPtr<void>(frame, &ip);
@@ -735,6 +836,7 @@ void resumeThread(
case kIROp_BufferLoad:
{
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
for( UInt aa = 0; aa < argCount; ++aa )
@@ -745,9 +847,8 @@ void resumeThread(
void* dest = decodeOperandPtr<void>(frame, &ip);
- VMType type = *(VMType*)argPtrs[0];
- char* bufferData = *(char**)argPtrs[1];
- uint32_t index = *(uint32_t*)argPtrs[2];
+ char* bufferData = *(char**)argPtrs[0];
+ uint32_t index = *(uint32_t*)argPtrs[1];
auto size = type.getSize();
char* elementData = bufferData + index*size;
@@ -757,9 +858,9 @@ void resumeThread(
case kIROp_BufferStore:
{
+ VMType resultType = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- VMType* resultTypePtr = decodeOperandPtr<VMType>(frame, &ip);
char* bufferData = decodeOperand<char*>(frame, &ip);
uint32_t index = decodeOperand<uint32_t>(frame, &ip);
@@ -774,8 +875,10 @@ void resumeThread(
break;
case kIROp_Call:
{
+ VMType type = decodeType(frame, &ip);
UInt operandCount = decodeUInt(&ip);
- VMType type = decodeOperand<VMType>(frame, &ip);
+
+ // First operand is the callee function
VMFunc* func = decodeOperand<VMFunc*>(frame, &ip);
// Okay, we need to create a frame to prepare the call
@@ -784,7 +887,7 @@ void resumeThread(
// Remaining arguments should populate the
// first N registers of the callee
- UInt argCount = operandCount - 2;
+ UInt argCount = operandCount - 1;
for( UInt aa = 0; aa < argCount; ++aa )
{
void* argPtr = decodeOperandPtr<void>(frame, &ip);
@@ -836,8 +939,8 @@ void resumeThread(
case kIROp_ReturnVal:
{
+ VMType instType = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- void* typePtr = decodeOperandPtr<void>(frame, &ip);
void* argPtr = decodeOperandPtr<void>(frame, &ip);
VMFrame* oldFrame = frame;
@@ -868,8 +971,8 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- void* typePtr = decodeOperandPtr<void>(frame, &ip);
Int destinationBlock = decodeSInt(&ip);
for( UInt aa = 2; aa < argCount; ++aa )
{
@@ -892,8 +995,8 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- void* typePtr = decodeOperandPtr<void>(frame, &ip);
bool* condition = decodeOperandPtr<bool>(frame, &ip);
Int trueBlockID = decodeSInt(&ip);
Int falseBlockID = decodeSInt(&ip);
@@ -917,9 +1020,9 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType resultType = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
- VMType resultType = decodeOperand<VMType>(frame, &ip);
auto leftOpnd = decodeOperandPtrAndType(frame, &ip);
auto type = leftOpnd.type;
auto leftPtr = leftOpnd.ptr;
@@ -942,8 +1045,8 @@ void resumeThread(
case kIROp_Mul:
{
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- VMType type = decodeOperand<VMType>(frame, &ip);
void* leftPtr = decodeOperandPtr<void>(frame, &ip);
void* rightPtr = decodeOperandPtr<void>(frame, &ip);
@@ -964,8 +1067,8 @@ void resumeThread(
case kIROp_Sub:
{
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- VMType type = decodeOperand<VMType>(frame, &ip);
void* leftPtr = decodeOperandPtr<void>(frame, &ip);
void* rightPtr = decodeOperandPtr<void>(frame, &ip);
@@ -989,6 +1092,7 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
for( UInt aa = 0; aa < argCount; ++aa )
@@ -1039,7 +1143,7 @@ SLANG_API void* SlangVMModule_findGlobalSymbolPtr(
char const* name)
{
return (SlangVMFunc*) Slang::findGlobalSymbolPtr(
- (Slang::VMFrame*) module,
+ (Slang::VMModule*) module,
name);
}
diff --git a/tests/ir/loop.slang.expected b/tests/ir/loop.slang.expected
index a0f8de432..e34bb68cc 100644
--- a/tests/ir/loop.slang.expected
+++ b/tests/ir/loop.slang.expected
@@ -2,77 +2,80 @@ result code = 0
standard error = {
}
standard output = {
-let %63 : Ptr<Array<Vec<Float32,4>,64>,1> = var()
-let %96 : Ptr<StructuredBuffer<Vec<Float32,4>>,0> = var()
-let %282 : Ptr<RWStructuredBuffer<Vec<Float32,4>>,0> = var()
-func @_S04mainp3 : (Int32, Int32, Int32) -> Void
+ir_global_var %1 : Ptr<vector<float,4>[64]>;
+
+ir_global_var %2 : Ptr<StructuredBuffer<vector<float,4>>>;
+
+ir_global_var %3 : Ptr<RWStructuredBuffer<vector<float,4>>>;
+
+ir_func @_S04mainp3 : (uint, uint, uint) -> void
{
-block %14( param %15 : Int32,
- param %24 : Int32,
- param %32 : Int32)
-:
- let %21 : Ptr<Int32,0> = var()
- store(%21, %15)
- let %29 : Ptr<Int32,0> = var()
- store(%29, %24)
- let %37 : Ptr<Int32,0> = var()
- store(%37, %32)
- let %64 : Int32 = load(%29)
- let %69 : Ptr<Vec<Float32,4>,0> = getElementPtr(%63, %64)
- let %97 : StructuredBuffer<Vec<Float32,4>> = load(%96)
- let %100 : Int32 = load(%21)
- let %101 : Vec<Float32,4> = bufferLoad(%97, %100)
- store(%69, %101)
- let %110 : Ptr<Int32,0> = var()
- let %118 : Int32 = construct(1)
- store(%110, %118)
- loop(%123, %129, %132)
+block %4(
+ param %5 : uint,
+ param %6 : uint,
+ param %7 : uint):
+ let %8 : Ptr<uint> = var()
+ store(%8, %5)
+ let %9 : Ptr<uint> = var()
+ store(%9, %6)
+ let %10 : Ptr<uint> = var()
+ store(%10, %7)
+ let %11 : uint = load(%9)
+ let %12 : Ptr<vector<float,4>> = getElementPtr(%1, %11)
+ let %13 : StructuredBuffer<vector<float,4>> = load(%2)
+ let %14 : uint = load(%8)
+ let %15 : vector<float,4> = bufferLoad(%13, %14)
+ store(%12, %15)
+ let %16 : Ptr<uint> = var()
+ let %17 : uint = construct(1)
+ store(%16, %17)
+ loop(%18, %19, %20)
-block %123:
- let %139 : Int32 = load(%110)
- let %148 : Int32 = construct(64)
- let %149 : Bool = cmpLT(%139, %148)
- loopTest(%149, %126, %129)
+block %18:
+ let %21 : uint = load(%16)
+ let %22 : uint = construct(64)
+ let %23 : bool = cmpLT(%21, %22)
+ loopTest(%23, %24, %19)
-block %126:
+block %24:
GroupMemoryBarrierWithGroupSync()
- let %174 : Int32 = load(%29)
- let %179 : Ptr<Vec<Float32,4>,0> = getElementPtr(%63, %174)
- let %184 : Ptr<Vec<Float32,4>,0> = var()
- let %185 : Vec<Float32,4> = load(%179)
- store(%184, %185)
- let %204 : Int32 = load(%29)
- let %207 : Int32 = load(%110)
- let %208 : Int32 = sub(%204, %207)
- let %213 : Ptr<Vec<Float32,4>,0> = getElementPtr(%63, %208)
- let %214 : Vec<Float32,4> = load(%213)
- let %215 : Vec<Float32,4> = load(%184)
- let %216 : Vec<Float32,4> = add(%215, %214)
- store(%184, %216)
- let %219 : Vec<Float32,4> = load(%184)
- store(%179, %219)
- unconditionalBranch(%132)
+ let %25 : uint = load(%9)
+ let %26 : Ptr<vector<float,4>> = getElementPtr(%1, %25)
+ let %27 : Ptr<vector<float,4>> = var()
+ let %28 : vector<float,4> = load(%26)
+ store(%27, %28)
+ let %29 : uint = load(%9)
+ let %30 : uint = load(%16)
+ let %31 : uint = sub(%29, %30)
+ let %32 : Ptr<vector<float,4>> = getElementPtr(%1, %31)
+ let %33 : vector<float,4> = load(%32)
+ let %34 : vector<float,4> = load(%27)
+ let %35 : vector<float,4> = add(%34, %33)
+ store(%27, %35)
+ let %36 : vector<float,4> = load(%27)
+ store(%26, %36)
+ unconditionalBranch(%20)
-block %132:
- let %232 : Ptr<Int32,0> = var()
- let %233 : Int32 = load(%110)
- store(%232, %233)
- let %244 : Int32 = construct(1)
- let %245 : Int32 = load(%232)
- let %246 : Int32 = shl(%245, %244)
- store(%232, %246)
- let %249 : Int32 = load(%232)
- store(%110, %249)
- unconditionalBranch(%123)
+block %20:
+ let %37 : Ptr<uint> = var()
+ let %38 : uint = load(%16)
+ store(%37, %38)
+ let %39 : uint = construct(1)
+ let %40 : uint = load(%37)
+ let %41 : uint = shl(%40, %39)
+ store(%37, %41)
+ let %42 : uint = load(%37)
+ store(%16, %42)
+ unconditionalBranch(%18)
-block %129:
+block %19:
GroupMemoryBarrierWithGroupSync()
- let %283 : RWStructuredBuffer<Vec<Float32,4>> = load(%282)
- let %286 : Int32 = load(%21)
- let %300 : Ptr<Vec<Float32,4>,0> = getElementPtr(%63, 0)
- let %301 : Vec<Float32,4> = load(%300)
- bufferStore(%283, %286, %301)
+ let %43 : RWStructuredBuffer<vector<float,4>> = load(%3)
+ let %44 : uint = load(%8)
+ let %45 : Ptr<vector<float,4>> = getElementPtr(%1, 0)
+ let %46 : vector<float,4> = load(%45)
+ bufferStore(%43, %44, %46)
return_void()
}
}
diff --git a/tools/eval-test/main.cpp b/tools/eval-test/main.cpp
index f8736a07b..9fb6f94a3 100644
--- a/tools/eval-test/main.cpp
+++ b/tools/eval-test/main.cpp
@@ -80,15 +80,15 @@ int main(
bytecode,
bytecodeSize);
- SlangVMFunc* vmFunc = *(SlangVMFunc**)SlangVMModule_findGlobalSymbolPtr(
+ SlangVMFunc* vmFunc = (SlangVMFunc*)SlangVMModule_findGlobalSymbolPtr(
vmModule,
"main");
- int32_t*& inputArg = **(int32_t***)SlangVMModule_findGlobalSymbolPtr(
+ int32_t*& inputArg = *(int32_t**)SlangVMModule_findGlobalSymbolPtr(
vmModule,
"input");
- int32_t*& outputArg = **(int32_t***)SlangVMModule_findGlobalSymbolPtr(
+ int32_t*& outputArg = *(int32_t**)SlangVMModule_findGlobalSymbolPtr(
vmModule,
"output");