summaryrefslogtreecommitdiffstats
path: root/source
diff options
context:
space:
mode:
Diffstat (limited to 'source')
-rw-r--r--source/slang/bytecode.cpp555
-rw-r--r--source/slang/bytecode.h102
-rw-r--r--source/slang/check.cpp27
-rw-r--r--source/slang/compiler.h8
-rw-r--r--source/slang/core.meta.slang11
-rw-r--r--source/slang/core.meta.slang.h12
-rw-r--r--source/slang/emit.cpp487
-rw-r--r--source/slang/ir-inst-defs.h29
-rw-r--r--source/slang/ir-insts.h236
-rw-r--r--source/slang/ir.cpp1065
-rw-r--r--source/slang/ir.h261
-rw-r--r--source/slang/lower-to-ir.cpp320
-rw-r--r--source/slang/lower.cpp19
-rw-r--r--source/slang/syntax.cpp132
-rw-r--r--source/slang/syntax.h23
-rw-r--r--source/slang/type-defs.h63
-rw-r--r--source/slang/vm.cpp616
17 files changed, 2299 insertions, 1667 deletions
diff --git a/source/slang/bytecode.cpp b/source/slang/bytecode.cpp
index e412a5b94..854d9bf34 100644
--- a/source/slang/bytecode.cpp
+++ b/source/slang/bytecode.cpp
@@ -19,8 +19,8 @@ struct SharedBytecodeGenerationContext;
template<typename T>
struct BytecodeGenerationPtr
{
- SharedBytecodeGenerationContext* sharedContext;
UInt offset;
+ SharedBytecodeGenerationContext* sharedContext;
BytecodeGenerationPtr()
: sharedContext(nullptr)
@@ -49,31 +49,40 @@ struct BytecodeGenerationPtr
, offset(ptr.offset)
{}
- operator BCPtr<T>()
+ template<typename U>
+ BytecodeGenerationPtr<U> bitCast() const
+ {
+ return BytecodeGenerationPtr<U>(sharedContext, offset);
+ }
+
+ operator BCPtr<T>() const
{
return BCPtr<T>(getPtr());
}
- T* operator->()
+ T* operator->() const
{
return getPtr();
}
- T& operator*()
+ T& operator*() const
{
return *getPtr();
}
- T& operator[](UInt index)
+ T& operator[](UInt index) const
{
return getPtr()[index];
}
- BytecodeGenerationPtr<T> operator+(Int index)
+ BytecodeGenerationPtr<T> operator+(Int index) const
{
+ UInt size = sizeof(T);
+ Int delta = index * sizeof(T);
+ UInt newOffset = offset + delta;
return BytecodeGenerationPtr<T>(
sharedContext,
- offset + index*sizeof(T));
+ newOffset);
}
T* getPtr() const;
@@ -93,14 +102,27 @@ struct SharedBytecodeGenerationContext
// The final generated bytecode stream
List<uint8_t> bytecode;
- // Map from a global symbol to its global ID
- Dictionary<IRInst*, Int> mapGlobalSymbolToGLobalID;
+ // Map from an IR value to a global entity
+ // that encodes it:
+ Dictionary<IRValue*, BCConst> mapValueToGlobal;
+
+ // Types that have been emitted
+ List<BytecodeGenerationPtr<BCType>> bcTypes;
+ Dictionary<Type*, UInt> mapTypeToID;
+
+ // Compile-time constant values that need
+ // to be emitted...
+ List<IRValue*> constants;
};
struct BytecodeGenerationContext
{
SharedBytecodeGenerationContext* shared;
+ // The bytecode of the current symbol being
+ // output.
+ List<uint8_t> currentBytecode;
+
// The function that is in scope for this context
IRFunc* currentIRFunc;
@@ -110,11 +132,7 @@ struct BytecodeGenerationContext
// Map an instruction to its ID for use local
// to the current context
- Dictionary<IRInst*, Int> mapInstToLocalID;
-
- // Map an instruction to the ID for its auxiliary
- // symbol data
- Dictionary<IRInst*, UInt> mapInstToNestedID;
+ Dictionary<IRValue*, Int> mapInstToLocalID;
};
template<typename T>
@@ -169,7 +187,7 @@ void encodeUInt8(
BytecodeGenerationContext* context,
uint8_t value)
{
- context->shared->bytecode.Add(value);
+ context->currentBytecode.Add(value);
}
void encodeUInt(
@@ -220,50 +238,98 @@ void encodeSInt(
encodeUInt(context, uValue);
}
-Int getLocalID(
+BCConst getGlobalValue(
BytecodeGenerationContext* context,
- IRInst* inst)
+ IRValue* value)
{
- Int localID = 0;
- if( context->mapInstToLocalID.TryGetValue(inst, localID) )
- {
- return localID;
- }
+ BCConst bcConst;
+ if( context->shared->mapValueToGlobal.TryGetValue(value, bcConst) )
+ return bcConst;
+
+ // Next we need to check for things that can be mapped to
+ // global IDs on the fly.
- Int globalID = 0;
- if( context->shared->mapGlobalSymbolToGLobalID.TryGetValue(inst, globalID) )
+ switch( value->op )
{
- BCConst bcConst;
- bcConst.globalID = globalID;
+ case kIROp_IntLit:
+ {
+ UInt constID = context->shared->constants.Count();
+ context->shared->constants.Add(value);
- UInt remappedSymbolIndex = context->remappedGlobalSymbols.Count();
- context->remappedGlobalSymbols.Add(bcConst);
+ BCConst bcConst;
+ bcConst.flavor = kBCConstFlavor_Constant;
+ bcConst.id = constID;
- localID = ~remappedSymbolIndex;
- context->mapInstToLocalID.Add(inst, localID);
- return localID;
+ context->shared->mapValueToGlobal.Add(value, bcConst);
+
+ return bcConst;
+ }
+ break;
+
+ default:
+ break;
}
SLANG_UNEXPECTED("no ID for inst");
- return -9999;
+ bcConst.flavor = (BCConstFlavor) -1;
+ bcConst.id = -9999;
+ return bcConst;
+}
+
+Int getLocalID(
+ BytecodeGenerationContext* context,
+ IRValue* value)
+{
+ Int localID = 0;
+ if( context->mapInstToLocalID.TryGetValue(value, localID) )
+ {
+ return localID;
+ }
+
+ BCConst bcConst = getGlobalValue(context, value);
+ UInt remappedSymbolIndex = context->remappedGlobalSymbols.Count();
+ context->remappedGlobalSymbols.Add(bcConst);
+
+ localID = ~remappedSymbolIndex;
+ context->mapInstToLocalID.Add(value, localID);
+ return localID;
}
void encodeOperand(
BytecodeGenerationContext* context,
- IRInst* operand)
+ IRValue* operand)
{
auto id = getLocalID(context, operand);
encodeSInt(context, id);
}
-bool opHasResult(IRInst* inst)
+uint32_t getTypeID(
+ BytecodeGenerationContext* context,
+ Type* type);
+
+void encodeOperand(
+ BytecodeGenerationContext* context,
+ IRType* type)
+{
+ encodeUInt(context, getTypeID(context, type));
+}
+
+bool opHasResult(IRValue* inst)
{
auto type = inst->getType();
- if( !type || type->op != kIROp_VoidType )
+ if (!type) return false;
+
+ // As a bit of a hack right now, we need to check whether
+ // the function returns the distinguished `Void` type,
+ // since that is conceptually the same as "not returning
+ // a value."
+ if (auto basicType = dynamic_cast<BasicExpressionType*>(type))
{
- return true;
+ if (basicType->baseType == BaseType::Void)
+ return false;
}
- return false;
+
+ return true;
}
void generateBytecodeForInst(
@@ -282,15 +348,16 @@ void generateBytecodeForInst(
//
auto argCount = inst->getArgCount();
+ auto type = inst->getType();
encodeUInt(context, inst->op);
+ encodeOperand(context, inst->getType());
encodeUInt(context, argCount);
for( UInt aa = 0; aa < argCount; ++aa )
{
encodeOperand(context, inst->getArg(aa));
}
- auto type = inst->getType();
- if( type && type->op == kIROp_VoidType )
+ if (!opHasResult(inst))
{
// This instructions has no type, so don't emit a destination
}
@@ -354,6 +421,7 @@ void generateBytecodeForInst(
}
break;
+#if 0
case kIROp_Func:
{
encodeUInt(context, inst->op);
@@ -368,6 +436,7 @@ void generateBytecodeForInst(
encodeOperand(context, inst);
}
break;
+#endif
case kIROp_Store:
{
@@ -375,9 +444,9 @@ void generateBytecodeForInst(
// We need to encode the type being stored, to make
// our lives easier.
- encodeOperand(context, inst->getArg(2)->getType());
+ encodeOperand(context, inst->getArg(1)->getType());
+ encodeOperand(context, inst->getArg(0));
encodeOperand(context, inst->getArg(1));
- encodeOperand(context, inst->getArg(2));
}
break;
@@ -385,33 +454,166 @@ void generateBytecodeForInst(
{
encodeUInt(context, inst->op);
encodeOperand(context, inst->getType());
- encodeOperand(context, inst->getArg(1));
+ encodeOperand(context, inst->getArg(0));
encodeOperand(context, inst);
}
break;
}
}
-Int getIDForGlobalSymbol(
+BytecodeGenerationPtr<BCType> emitBCType(
+ BytecodeGenerationContext* context,
+ Type* type,
+ IROp op,
+ BytecodeGenerationPtr<uint8_t> const* args,
+ UInt argCount)
+{
+ UInt size = sizeof(BCType)
+ + argCount * sizeof(BCPtr<void>);
+
+ BytecodeGenerationPtr<uint8_t> bcAllocation(
+ context->shared,
+ allocateRaw(context, size, alignof(BCPtr<void>)));
+
+ BytecodeGenerationPtr<BCType> bcType = bcAllocation.bitCast<BCType>();
+ auto bcArgs = (bcType + 1).bitCast<BCPtr<uint8_t>>();
+
+ bcType->op = op;
+ bcType->argCount = argCount;
+
+ for(UInt aa = 0; aa < argCount; ++aa)
+ {
+ bcArgs[aa] = args[aa];
+ }
+
+ UInt id = context->shared->bcTypes.Count();
+ context->shared->mapTypeToID.Add(type, id);
+ context->shared->bcTypes.Add(bcType);
+ bcType->id = id;
+
+ return bcType;
+}
+
+BytecodeGenerationPtr<BCType> emitBCVarArgType(
+ BytecodeGenerationContext* context,
+ Type* type,
+ IROp op,
+ List<BytecodeGenerationPtr<uint8_t>> args)
+{
+ return emitBCType(context, type, op, args.Buffer(), args.Count());
+}
+
+BytecodeGenerationPtr<BCType> emitBCType(
BytecodeGenerationContext* context,
- IRInst* inst)
+ Type* type,
+ IROp op)
+{
+ return emitBCType(context, type, op, nullptr, 0);
+}
+
+BytecodeGenerationPtr<BCType> emitBCType(
+ BytecodeGenerationContext* context,
+ Type* type);
+
+// Emit a `BCType` representation for the given `Type`
+BytecodeGenerationPtr<BCType> emitBCTypeImpl(
+ BytecodeGenerationContext* context,
+ Type* type)
+{
+ // A NULL type is interpreted as equivalent to `Void` for now.
+ if( !type )
+ {
+ return emitBCType(context, type, kIROp_VoidType);
+ }
+
+ if( auto basicType = type->As<BasicExpressionType>() )
+ {
+ switch(basicType->baseType)
+ {
+ case BaseType::Void: return emitBCType(context, type, kIROp_VoidType);
+ case BaseType::Bool: return emitBCType(context, type, kIROp_BoolType);
+ case BaseType::Int: return emitBCType(context, type, kIROp_Int32Type);
+ case BaseType::UInt: return emitBCType(context, type, kIROp_UInt32Type);
+ case BaseType::UInt64: return emitBCType(context, type, kIROp_UInt64Type);
+ case BaseType::Half: return emitBCType(context, type, kIROp_Float16Type);
+ case BaseType::Float: return emitBCType(context, type, kIROp_Float32Type);
+ case BaseType::Double: return emitBCType(context, type, kIROp_Float64Type);
+
+ default:
+ break;
+ }
+ }
+ else if( auto funcType = type->As<FuncType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+
+ operands.Add(emitBCType(context, funcType->resultType).bitCast<uint8_t>());
+ UInt paramCount = funcType->getParamCount();
+ for(UInt pp = 0; pp < paramCount; ++pp)
+ {
+ operands.Add(emitBCType(context, funcType->getParamType(pp)).bitCast<uint8_t>());
+ }
+
+ return emitBCVarArgType(context, type, kIROp_FuncType, operands);
+ }
+ else if( auto ptrType = type->As<PtrType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+ operands.Add(emitBCType(context, ptrType->getValueType()).bitCast<uint8_t>());
+ return emitBCVarArgType(context, type, kIROp_PtrType, operands);
+ }
+ else if( auto rwStructuredBufferType = type->As<HLSLRWStructuredBufferType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+ operands.Add(emitBCType(context, rwStructuredBufferType->elementType).bitCast<uint8_t>());
+ return emitBCVarArgType(context, type, kIROp_readWriteStructuredBufferType, operands);
+ }
+ else if( auto structuredBufferType = type->As<HLSLStructuredBufferType>() )
+ {
+ List<BytecodeGenerationPtr<uint8_t>> operands;
+ operands.Add(emitBCType(context, structuredBufferType->elementType).bitCast<uint8_t>());
+ return emitBCVarArgType(context, type, kIROp_structuredBufferType, operands);
+ }
+
+
+ SLANG_UNEXPECTED("unimplemented");
+ return BytecodeGenerationPtr<BCType>();
+}
+
+BytecodeGenerationPtr<BCType> emitBCType(
+ BytecodeGenerationContext* context,
+ Type* type)
{
- Int globalID;
- if(context->shared->mapGlobalSymbolToGLobalID.TryGetValue(inst, globalID))
- return globalID;
+ auto canonical = type->GetCanonicalType();
+ UInt id = 0;
+ if(context->shared->mapTypeToID.TryGetValue(canonical, id))
+ {
+ return context->shared->bcTypes[id];
+ }
- SLANG_UNEXPECTED("no such ID");
+ BytecodeGenerationPtr<BCType> bcType = emitBCTypeImpl(context, canonical);
+ return bcType;
}
-uint32_t getTypeForGlobalSymbol(
+uint32_t getTypeID(
BytecodeGenerationContext* context,
- IRInst* inst)
+ Type* type)
+{
+ // We have a type, and we need to emit it (if we haven't
+ // already) and return its index in the global type table.
+ BytecodeGenerationPtr<BCType> bcType = emitBCType(context, type);
+ return bcType->id;
+}
+
+uint32_t getTypeIDForGlobalSymbol(
+ BytecodeGenerationContext* context,
+ IRValue* inst)
{
auto type = inst->getType();
if(!type)
return 0;
- return getIDForGlobalSymbol(context, type);
+ return getTypeID(context, type);
}
BytecodeGenerationPtr<char> allocateString(
@@ -442,7 +644,7 @@ BytecodeGenerationPtr<char> allocateString(
BytecodeGenerationPtr<char> tryGenerateNameForSymbol(
BytecodeGenerationContext* context,
- IRInst* inst)
+ IRGlobalValue* inst)
{
// TODO: this is gross, and the IR should probably have
// a more direct means of querying a name for a symbol.
@@ -462,9 +664,10 @@ BytecodeGenerationPtr<char> tryGenerateNameForSymbol(
return BytecodeGenerationPtr<char>();
}
+// Generate a `BCSymbol` that can represent a global value.
BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
BytecodeGenerationContext* context,
- IRInst* inst)
+ IRGlobalValue* inst)
{
switch( inst->op )
{
@@ -474,7 +677,7 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
BytecodeGenerationPtr<BCFunc> bcFunc = allocate<BCFunc>(context);
bcFunc->op = inst->op;
- bcFunc->typeGlobalID = getTypeForGlobalSymbol(context, inst);
+ bcFunc->typeID = getTypeIDForGlobalSymbol(context, inst);
BytecodeGenerationContext subContextStorage;
BytecodeGenerationContext* subContext = &subContextStorage;
@@ -515,7 +718,17 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
{
UInt blockID = blockCounter++;
UInt paramCount = 0;
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+
+ for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() )
+ {
+ // A parameter always uses a register.
+ regCounter++;
+ //
+ // We also want to keep a count of the parameters themselves.
+ paramCount++;
+ }
+
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
{
switch( ii->op )
{
@@ -528,15 +741,6 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
break;
- case kIROp_Param:
- // A parameter always uses a register.
- regCounter++;
- //
- // We also want to keep a count of the parameters themselves.
- paramCount++;
- break;
-
-
case kIROp_Var:
// A `var` (`alloca`) node needs two registers:
// one to hold the actual storage, and another
@@ -573,30 +777,25 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
// are always the first N registers in the overall list.
//
bcBlocks[blockID].params = bcRegs + regCounter;
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+ for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() )
{
- if(ii->op != kIROp_Param)
- continue;
-
Int localID = regCounter++;
- subContext->mapInstToLocalID.Add(ii, localID);
+ subContext->mapInstToLocalID.Add(pp, localID);
- bcRegs[localID].op = ii->op;
- bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
+ bcRegs[localID].op = pp->op;
+#if 0
+ bcRegs[localID].name = tryGenerateNameForSymbol(context, pp);
+#endif
bcRegs[localID].previousVarIndexPlusOne = localID;
- bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii);
+ bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, pp);
}
// Now loop over the non-parameter instructions and
// allocate actual register locations to them.
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
{
switch(ii->op)
{
- case kIROp_Param:
- // Already handled.
- break;
-
default:
// For an ordinary instruction with a result,
// allocate it here.
@@ -606,9 +805,11 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
subContext->mapInstToLocalID.Add(ii, localID);
bcRegs[localID].op = ii->op;
+#if 0
bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
+#endif
bcRegs[localID].previousVarIndexPlusOne = localID;
- bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii);
+ bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii);
}
break;
@@ -624,30 +825,39 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
subContext->mapInstToLocalID.Add(ii, localID);
bcRegs[localID].op = ii->op;
+#if 0
bcRegs[localID].name = tryGenerateNameForSymbol(context, ii);
+#endif
bcRegs[localID].previousVarIndexPlusOne = localID;
- bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii);
+ bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii);
bcRegs[localID+1].op = ii->op;
bcRegs[localID+1].previousVarIndexPlusOne = localID+1;
- bcRegs[localID+1].typeGlobalID = getIDForGlobalSymbol(context,
- ((IRPtrType*) ii->getType())->getValueType());
+ bcRegs[localID+1].typeID = getTypeID(context,
+ (ii->getType()->As<PtrType>())->getValueType());
}
break;
}
}
}
+ assert(regCounter == regCount);
// Now that we've allocated our blocks and our registers
// we can go through the actual process of emitting instructions. Hooray!
blockCounter = 0;
+
+ // Offset of each basic block from the start of the code
+ // for the current funciton.
+ List<UInt> blockOffsets;
for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() )
{
UInt blockID = blockCounter++;
- bcBlocks[blockID].code = getPtr<BCOp>(context);
+ // Get local bytecode offset for current block.
+ UInt blockOffset = subContext->currentBytecode.Count();
+ blockOffsets.Add( blockOffset );
- for( auto ii = bb->firstChild; ii; ii = ii->nextInst )
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
{
// What we do with each instruction depends a bit on the
// kind of instruction it is.
@@ -671,6 +881,22 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
}
+
+ // We've collected bytecode for the instruction stream
+ // into a sub-context, so we can now append that code.
+ UInt byteCount = subContext->currentBytecode.Count();
+ BytecodeGenerationPtr<uint8_t> bytes = allocateArray<uint8_t>(context, byteCount);
+ memcpy(&bytes[0], subContext->currentBytecode.Buffer(), byteCount);
+
+ // Now that we've allocated the storage, we can write
+ // the bytecode pointers into the blocks.
+ blockCounter = 0;
+ for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() )
+ {
+ UInt blockID = blockCounter++;
+ bcBlocks[blockID].code = bytes + blockOffsets[blockID];
+ }
+
// Finally, after emitting all the instructions we can
// build a table of global symbols taht need to be
// imported into the current function as constants.
@@ -689,6 +915,17 @@ BytecodeGenerationPtr<BCSymbol> generateBytecodeSymbolForInst(
}
break;
+ case kIROp_global_var:
+ {
+ auto bcVar = allocate<BCSymbol>(context);
+
+ bcVar->op = inst->op;
+ bcVar->typeID = getTypeID(context, inst->type);
+
+ return bcVar;
+ }
+ break;
+
default:
// Most instructions don't need a custom representation.
return BytecodeGenerationPtr<BCSymbol>();
@@ -699,126 +936,98 @@ BytecodeGenerationPtr<BCModule> generateBytecodeForModule(
BytecodeGenerationContext* context,
IRModule* irModule)
{
- // The module will get encoded much like a function,
- // and then that function will be "invoked" to load
- // the module.
+ // A module in the bytecode is mostly just a list of the
+ // global symbols in the module.
//
- auto bcModule = allocate<BCModule>(context);
- bcModule->op = irModule->op;
- bcModule->typeGlobalID = 0;
-
- // The logical function that the module representats
- // will only have a single block, so we can allocate it now.
+ // TODO: we need to be careful and recognize the distinction
+ // between the global symbols in the *AST* module, vs. those
+ // symbols which are effectively global in the *IR* module.
//
- auto bcBlock = allocate<BCBlock>(context);
- bcBlock->paramCount = 0;
- bcBlock->params = BytecodeGenerationPtr<BCReg>();
-
- bcModule->blockCount = 1;
- bcModule->blocks = bcBlock;
-
- // Because the module is the top-most level, there is
- // no need for it to have "constants" that represent
- // values imported from the next outer scope.
+ // We probably need to store these distinctly, since we
+ // need the AST global symbols for reflection, and then
+ // also to reconstruct the AST on load when importing a
+ // serialized module. We then need the global IR symbols
+ // to use when linking, to quickly resolve things without
+ // needing any semantic knowledge of nesting at the AST level.
//
- bcModule->constCount = 0;
- bcModule->consts = BytecodeGenerationPtr<BCConst>();
+ auto bcModule = allocate<BCModule>(context);
// We need to compute how many "registers" to allocate
// for the module, where the registers represent the
// values being computed at the global scope.
- UInt regCount = 0;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
+ UInt symbolCount = 0;
+ for( auto gv : irModule->globalValues )
{
- if(!opHasResult(inst))
- continue;
-
- Int globalID = Int(regCount++);
+ Int globalID = Int(symbolCount++);
- context->shared->mapGlobalSymbolToGLobalID.Add(inst, globalID);
+ // Ensure that local code inside functions can see these symbols
+ BCConst bcConst;
+ bcConst.flavor = kBCConstFlavor_GlobalSymbol;
+ bcConst.id = globalID;
+ context->shared->mapValueToGlobal.Add(gv, bcConst);
// In the global scope, global IDs are also the local IDs
- context->mapInstToLocalID.Add(inst, globalID);
+ context->mapInstToLocalID.Add(gv, globalID);
}
- auto bcRegs = allocateArray<BCReg>(context, regCount);
+ auto bcSymbols = allocateArray<BCPtr<BCSymbol>>(context, symbolCount);
- bcModule->regCount = regCount;
- bcModule->regs = bcRegs;
+ bcModule->symbolCount = symbolCount;
+ bcModule->symbols = bcSymbols;
- // Now lets walk through and initialize all those bytecode
- // register representations so that we can use them.
- UInt regCounter= 0;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
+ for( auto gv : irModule->globalValues )
{
- if(!opHasResult(inst))
- continue;
-
- UInt regIndex = *context->mapInstToLocalID.TryGetValue(inst);
+ UInt symbolIndex = *context->mapInstToLocalID.TryGetValue(gv);
- BytecodeGenerationPtr<char> name = tryGenerateNameForSymbol(context, inst);
-
- bcRegs[regIndex].op = inst->op;
- bcRegs[regIndex].name = name;
- bcRegs[regIndex].typeGlobalID = getTypeForGlobalSymbol(context, inst);
- bcRegs[regIndex].previousVarIndexPlusOne = regIndex;
- }
-
- // Some instructions represent "nested" symbols that will need
- // custom handling, and we will represent those here.
- List<BytecodeGenerationPtr<BCSymbol>> nestedSymbols;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
- {
- UInt regIndex = *context->mapInstToLocalID.TryGetValue(inst);
-
- auto bcSymbol = generateBytecodeSymbolForInst(context, inst);
+ auto bcSymbol = generateBytecodeSymbolForInst(context, gv);
if (!bcSymbol.getPtr())
continue;
- UInt nestedSymbolID = nestedSymbols.Count();
- nestedSymbols.Add(bcSymbol);
-
- context->mapInstToNestedID.Add(inst, nestedSymbolID);
+ auto name = tryGenerateNameForSymbol(context, gv);
+ bcSymbol->name = name;
- bcSymbol->name = bcRegs[regIndex].name;
+ bcSymbols[symbolIndex] = bcSymbol;
}
- auto nestedSymbolCount = nestedSymbols.Count();
- auto bcNestedSymbols = allocateArray<BCPtr<BCSymbol>>(context, nestedSymbolCount);
+ // At this point we should have identified all the literals we need:
+ UInt constantCount = context->shared->constants.Count();
+ auto bcConstants = allocateArray<BCConstant>(context, constantCount);
+ bcModule->constantCount = constantCount;
+ bcModule->constants = bcConstants;
- bcModule->nestedSymbolCount = nestedSymbolCount;
- bcModule->nestedSymbols = bcNestedSymbols;
- for (UInt ii = 0; ii < nestedSymbolCount; ++ii)
+ for(UInt cc = 0; cc < constantCount; ++cc)
{
- bcNestedSymbols[ii] = nestedSymbols[ii];
- }
-
+ auto irConstant = (IRConstant*) context->shared->constants[cc];
+ bcConstants[cc].op = irConstant->op;
+ bcConstants[cc].typeID = getTypeID(context, irConstant->type);
- // Finally, we can go through and emit the actual code for
- // the initialization step.
- bcBlock->code = getPtr<BCOp>(context);
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
- {
- // Generate bytecode for global-scope inst
- generateBytecodeForInst(context, inst);
- }
- // Need to encode a terminator here, just to keep the encoding valid
- encodeUInt(context, kIROp_ReturnVoid);
+ switch(irConstant->op)
+ {
+ case kIROp_IntLit:
+ {
+ auto ptr = allocate<IRIntegerValue>(context);
+ *ptr = irConstant->u.intVal;
+ bcConstants[cc].ptr = ptr.bitCast<uint8_t>();
+ }
+ break;
-#if 0
+ default:
+ break;
+ }
- // Now we can go through and generate the bytecode object
- // that will represent each of these global symbols
+ }
- List<BytecodeGenerationPtr<BCSymbol>> globalSymbols;
+ // At this point we should have collected all the types we need:
+ UInt typeCount = context->shared->bcTypes.Count();
+ auto bcTypes = allocateArray<BCPtr<BCType>>(context, typeCount);
+ bcModule->typeCount = typeCount;
+ bcModule->types = bcTypes;
- for( auto inst = irModule->firstChild; inst; inst = inst->nextInst )
+ for(UInt tt = 0; tt < typeCount; ++tt)
{
- // Generate bytecode for global-scope inst
- auto globalSymbol = generateBytecodeForGlobalSymbol(context, inst);
- globalSymbols.Add(globalSymbol);
+ bcTypes[tt] = context->shared->bcTypes[tt];
}
-#endif
+
return bcModule;
}
@@ -833,10 +1042,6 @@ void generateBytecodeStream(
memcpy(header->magic, "slang\0bc", sizeof(header->magic));
header->version = 0;
- // HACK: ensure that a NULL pointer in an operand field can
- // be encoded.
- context->shared->mapGlobalSymbolToGLobalID.Add(nullptr, -1);
-
header->module = generateBytecodeForModule(context, irModule);
}
diff --git a/source/slang/bytecode.h b/source/slang/bytecode.h
index e073763cf..1ea16406f 100644
--- a/source/slang/bytecode.h
+++ b/source/slang/bytecode.h
@@ -65,6 +65,7 @@ struct BCPtr
}
operator T*() const { return getPtr(); }
+ T* operator->() const { return getPtr(); }
T* getPtr() const
{
@@ -73,9 +74,52 @@ struct BCPtr
}
};
-struct BCType
+// Representation of a "type-level" value in
+// the bytecode fiel. This corresponds to
+// the AST-level notion of a `Val`
+struct BCVal
{
+ // The opcode used to define this value
uint32_t op;
+
+ // The ID of the type within its module
+ uint32_t id;
+};
+
+struct BCType : BCVal
+{
+ // TODO: avoid having to encode this?
+ uint32_t argCount;
+
+ // type-specific operands follow
+
+ //
+
+ BCPtr<BCVal>* getArgs() { return (BCPtr<BCVal>*) (this +1); }
+
+ BCVal* getArg(UInt index) { return getArgs()[index]; }
+};
+
+struct BCPtrType : BCType
+{
+ BCPtr<BCType> valueType;
+};
+
+struct BCFuncType : BCType
+{
+ BCPtr<BCType> resultType;
+ BCPtr<BCType> paramTypes[1];
+
+ BCType* getResultType() { return resultType; }
+
+ UInt getParamCount() { return argCount - 1; }
+ BCType* getParamType(UInt index) { return paramTypes[index]; }
+};
+
+struct BCConstant : BCVal
+{
+ uint32_t typeID;
+ BCPtr<uint8_t> ptr;
};
struct BCSymbol
@@ -84,15 +128,9 @@ struct BCSymbol
// this symbol; used to categorize things
uint32_t op;
- // The type of the symbol is represent
- // as an index into the global-scope symbol
- // list of the module.
- //
- // Note: This currently precludes having
- // a register with a type that is not
- // statically determined, but that is
- // probably okay.
- uint32_t typeGlobalID;
+ // The index (in the module's type table)
+ // of the type of the symbol:
+ uint32_t typeID;
// The name of this symbol (which might
// be a mangled name at some point,
@@ -111,17 +149,18 @@ struct BCReg : BCSymbol
uint32_t previousVarIndexPlusOne;
};
+enum BCConstFlavor
+{
+ kBCConstFlavor_GlobalSymbol,
+ kBCConstFlavor_Constant,
+};
+
struct BCConst
{
- // The ID of the symbol in the global
- // scope that we are trying to refer
- // to.
- //
- // TODO: eventually, if we have general
- // nesting, then this might be the
- // entry in the outer scope that
- // is being referenced.
- uint32_t globalID;
+ // The flavor of bytecode constant we
+ // are dealing with.
+ uint32_t flavor;
+ uint32_t id;
};
struct BCBlock
@@ -153,16 +192,27 @@ struct BCFunc : BCSymbol
// but this would make the encoding less dense.
uint32_t constCount;
BCPtr<BCConst> consts;
-
- // Data for "nested" symbols (e.g., a function
- // nested inside this function).
- uint32_t nestedSymbolCount;
- BCPtr<BCPtr<BCSymbol>> nestedSymbols;
};
-// A module is encoded more or less like a function.
-struct BCModule : BCFunc
+struct BCModule
{
+ // The symbols (functions, global variables, etc.)
+ // that have been declared in the module.
+ uint32_t symbolCount;
+ BCPtr<BCPtr<BCSymbol>> symbols;
+
+ // The types that are used by this module, stored
+ // in a single array so that they can be conveniently
+ // mapped to another representation in one go.
+ //
+ // Instructions in a bytecode instruction sequence
+ // might reference these types by index.
+ uint32_t typeCount;
+ BCPtr<BCPtr<BCType>> types;
+
+ // True compile-time constants go here:
+ uint32_t constantCount;
+ BCPtr<BCConstant> constants;
};
struct BCHeader
diff --git a/source/slang/check.cpp b/source/slang/check.cpp
index 5c1f7380c..73d464d95 100644
--- a/source/slang/check.cpp
+++ b/source/slang/check.cpp
@@ -4546,26 +4546,21 @@ namespace Slang
// if this is still an invoke expression, test arguments passed to inout/out parameter are LValues
if(auto funcType = invoke->FunctionExpr->type->As<FuncType>())
{
- List<RefPtr<ParamDecl>> paramsStorage;
- List<RefPtr<ParamDecl>> * params = nullptr;
- if (auto func = funcType->declRef.getDecl())
+ UInt paramCount = funcType->getParamCount();
+ for (UInt pp = 0; pp < paramCount; ++pp)
{
- paramsStorage = func->GetParameters().ToArray();
- params = &paramsStorage;
- }
- if (params)
- {
- for (UInt i = 0; i < (*params).Count(); i++)
+ auto paramType = funcType->getParamType(pp);
+ if (auto outParamType = paramType->As<OutTypeBase>())
{
- if ((*params)[i]->HasModifier<OutModifier>())
+ if (pp < expr->Arguments.Count()
+ && !expr->Arguments[pp]->type.IsLeftValue)
{
- if (i < expr->Arguments.Count() && expr->Arguments[i]->type->AsBasicType() &&
- !expr->Arguments[i]->type.IsLeftValue)
+ if (!isRewriteMode())
{
- if (!isRewriteMode())
- {
- getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->getName());
- }
+ getSink()->diagnose(
+ expr->Arguments[pp],
+ Diagnostics::argumentExpectedLValue,
+ pp);
}
}
}
diff --git a/source/slang/compiler.h b/source/slang/compiler.h
index 7ec812d63..1919699a1 100644
--- a/source/slang/compiler.h
+++ b/source/slang/compiler.h
@@ -15,6 +15,7 @@ namespace Slang
struct IncludeHandler;
class CompileRequest;
class ProgramLayout;
+ class PtrType;
enum class CompilerMode
{
@@ -369,6 +370,7 @@ namespace Slang
RefPtr<Type> errorType;
RefPtr<Type> initializerListType;
RefPtr<Type> overloadedType;
+ RefPtr<Type> irBasicBlockType;
Dictionary<int, RefPtr<Type>> builtinTypes;
Dictionary<String, Decl*> magicDecls;
@@ -388,6 +390,12 @@ namespace Slang
Type* getOverloadedType();
Type* getErrorType();
+ // Should not be used in front-end code
+ Type* getIRBasicBlockType();
+
+ // Construct pointer types on-demand
+ RefPtr<PtrType> getPtrType(RefPtr<Type> valueType);
+
SyntaxClass<RefObject> findSyntaxClass(Name* name);
Dictionary<Name*, SyntaxClass<RefObject> > mapNameToSyntaxClass;
diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang
index 2e200f63a..3755c51c7 100644
--- a/source/slang/core.meta.slang
+++ b/source/slang/core.meta.slang
@@ -85,6 +85,17 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
sb << "};\n";
}
+// Declare built-in pointer type
+// (eventually we can have the traditional syntax sugar for this)
+
+}}}}
+
+__generic<T>
+__magic_type(PtrType)
+struct Ptr
+{};
+
+${{{{
// Declare vector and matrix types
diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h
index cf2052d3c..864196944 100644
--- a/source/slang/core.meta.slang.h
+++ b/source/slang/core.meta.slang.h
@@ -86,6 +86,18 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt)
sb << "};\n";
}
+// Declare built-in pointer type
+// (eventually we can have the traditional syntax sugar for this)
+
+sb << "\n";
+sb << "\n";
+sb << "__generic<T>\n";
+sb << "__magic_type(PtrType)\n";
+sb << "struct Ptr\n";
+sb << "{};\n";
+sb << "\n";
+sb << "";
+
// Declare vector and matrix types
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index d4c1be706..64ac85bd4 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -93,6 +93,10 @@ struct SharedEmitContext
bool needHackSamplerForTexelFetch = false;
ExtensionUsageTracker extensionUsageTracker;
+
+ Dictionary<IRValue*, UInt> mapIRValueToID;
+
+ HashSet<Decl*> irDeclsVisited;
};
struct EmitContext
@@ -1025,6 +1029,9 @@ struct EmitVisitor
UNEXPECTED(GenericDeclRefType);
UNEXPECTED(InitializerListType);
+ UNEXPECTED(IRBasicBlockType);
+ UNEXPECTED(PtrType);
+
#undef UNEXPECTED
void visitNamedExpressionType(NamedExpressionType* type, TypeEmitArg const& arg)
@@ -1231,6 +1238,17 @@ struct EmitVisitor
EmitType(type, SourceLoc(), name, SourceLoc());
}
+ void EmitType(RefPtr<Type> type, String const& name)
+ {
+ // HACK: the rest of the code wants a `Name`,
+ // so we'll create one for a bit...
+ Name tempName;
+ tempName.text = name;
+
+ EmitType(type, SourceLoc(), &tempName, SourceLoc());
+ }
+
+
void EmitType(RefPtr<Type> type)
{
emitTypeImpl(type, nullptr);
@@ -3942,8 +3960,34 @@ emitDeclImpl(decl, nullptr);
// IR-level emit logc
- String getName(IRInst* inst)
+ UInt getID(IRValue* value)
+ {
+ auto& mapIRValueToID = context->shared->mapIRValueToID;
+
+ UInt id = 0;
+ if (mapIRValueToID.TryGetValue(value, id))
+ return id;
+
+ id = mapIRValueToID.Count() + 1;
+ mapIRValueToID.Add(value, id);
+ return id;
+ }
+
+ String getName(IRValue* inst)
{
+ switch(inst->op)
+ {
+ case kIROp_decl_ref:
+ {
+ auto irDeclRef = (IRDeclRef*) inst;
+ return getText(irDeclRef->declRef.GetName());
+ }
+ break;
+
+ default:
+ break;
+ }
+
if(auto decoration = inst->findDecoration<IRHighLevelDeclDecoration>())
{
auto decl = decoration->decl;
@@ -3957,7 +4001,7 @@ emitDeclImpl(decl, nullptr);
StringBuilder sb;
sb << "_S";
- sb << inst->id;
+ sb << getID(inst);
return sb.ProduceString();
}
@@ -4035,7 +4079,7 @@ emitDeclImpl(decl, nullptr);
}
-
+#if 0
void emitIRSimpleType(
EmitContext* context,
IRType* type)
@@ -4153,12 +4197,14 @@ emitDeclImpl(decl, nullptr);
}
}
+#endif
CodeGenTarget getTarget(EmitContext* context)
{
return context->shared->target;
}
+#if 0
void emitGLSLTypePrefix(
EmitContext* context,
IRType* type)
@@ -4194,44 +4240,9 @@ emitDeclImpl(decl, nullptr);
break;
}
}
+#endif
- void emitIRVectorType(
- EmitContext* context,
- IRVectorType* type)
- {
- switch(getTarget(context))
- {
- case CodeGenTarget::GLSL:
- // HLSL style: `<elementTypePrefix>vec<elementCount>`
- // e.g., `ivec4`
- //
- emitGLSLTypePrefix(context, type->getElementType());
- emit("vec");
- emitIRSimpleValue(context, type->getElementCount());
- break;
-
- default:
- // HLSL style: `<elementTypeName><elementCount>`
- // e.g., `int4`
- //
- emitIRSimpleType(context, type->getElementType());
- emitIRSimpleValue(context, type->getElementCount());
- break;
- }
- }
-
- void emitIRMatrixType(
- EmitContext* context,
- IRMatrixType* type)
- {
- // TODO: this is a GLSL-vs-HLSL decision point
-
- emitIRSimpleType(context, type->getElementType());
- emitIRSimpleValue(context, type->getRowCount());
- emit("x");
- emitIRSimpleValue(context, type->getColumnCount());
- }
-
+#if 0
void emitIRType(
EmitContext* context,
IRType* type,
@@ -4287,10 +4298,11 @@ emitDeclImpl(decl, nullptr);
{
emitIRType(context, type, (IRDeclaratorInfo*) nullptr);
}
+#endif
bool shouldFoldIRInstIntoUseSites(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
// Certain opcodes should always be folded in
switch( inst->op )
@@ -4306,27 +4318,24 @@ emitDeclImpl(decl, nullptr);
return true;
}
- // Certain *types* will usually want to be folded in
+ // Certain *types* will usually want to be folded in,
+ // because they aren't allowed as types for temporary
+ // variables.
auto type = inst->getType();
- switch (type->op)
+ if(type->As<UniformParameterBlockType>())
{
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
// TODO: we need to be careful here, because
// HLSL shader model 6 allows these as explicit
// types.
return true;
-
- case kIROp_TextureType:
+ }
+ else if(type->As<TextureTypeBase>())
+ {
// GLSL doesn't allow texture/resource types to
// be used as first-class values, so we need
// to fold them into their use sites in all cases
if(getTarget(context) == CodeGenTarget::GLSL)
return true;
- break;
-
- default:
- break;
}
// By default we will *not* fold things into their use sites.
@@ -4335,20 +4344,16 @@ emitDeclImpl(decl, nullptr);
bool isDerefBaseImplicit(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
auto type = inst->getType();
- switch (type->op)
+
+ if(type->As<UniformParameterBlockType>())
{
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
// TODO: we need to be careful here, because
// HLSL shader model 6 allows these as explicit
// types.
return true;
-
- default:
- break;
}
return false;
@@ -4358,7 +4363,7 @@ emitDeclImpl(decl, nullptr);
void emitIROperand(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
if( shouldFoldIRInstIntoUseSites(context, inst) )
{
@@ -4380,8 +4385,8 @@ emitDeclImpl(decl, nullptr);
EmitContext* context,
IRInst* inst)
{
- UInt argCount = inst->argCount - 1;
- IRUse* args = inst->getArgs() + 1;
+ UInt argCount = inst->argCount;
+ IRUse* args = inst->getArgs();
emit("(");
for(UInt aa = 0; aa < argCount; ++aa)
@@ -4392,12 +4397,35 @@ emitDeclImpl(decl, nullptr);
emit(")");
}
+ void emitIRType(
+ EmitContext* context,
+ IRType* type,
+ String const& name)
+ {
+ EmitType(type, name);
+ }
+
+ void emitIRType(
+ EmitContext* context,
+ IRType* type,
+ Name* name)
+ {
+ EmitType(type, name);
+ }
+
+ void emitIRType(
+ EmitContext* context,
+ IRType* type)
+ {
+ EmitType(type);
+ }
+
void emitIRInstResultDecl(
EmitContext* context,
IRInst* inst)
{
auto type = inst->getType();
- if(!type || type->op == kIROp_VoidType)
+ if(!type)
return;
emitIRType(context, type, getName(inst));
@@ -4406,9 +4434,10 @@ emitDeclImpl(decl, nullptr);
void emitIRInstExpr(
EmitContext* context,
- IRInst* inst)
+ IRValue* value)
{
- switch(inst->op)
+ IRInst* inst = (IRInst*) value;
+ switch(value->op)
{
case kIROp_IntLit:
case kIROp_FloatLit:
@@ -4418,13 +4447,13 @@ emitDeclImpl(decl, nullptr);
case kIROp_Construct:
// Simple constructor call
- if( inst->getArgCount() == 2 && getTarget(context) == CodeGenTarget::HLSL)
+ if( inst->getArgCount() == 1 && getTarget(context) == CodeGenTarget::HLSL)
{
// Need to emit as cast for HLSL
emit("(");
emitIRType(context, inst->getType());
emit(") ");
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
}
else
{
@@ -4446,7 +4475,7 @@ emitDeclImpl(decl, nullptr);
emitIRType(context, inst->getType());
}
emit("(");
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(")");
break;
@@ -4480,9 +4509,9 @@ emitDeclImpl(decl, nullptr);
#define CASE(OPCODE, OP) \
case OPCODE: \
- emitIROperand(context, inst->getArg(1)); \
+ emitIROperand(context, inst->getArg(0)); \
emit("" #OP " "); \
- emitIROperand(context, inst->getArg(2)); \
+ emitIROperand(context, inst->getArg(1)); \
break
CASE(kIROp_Add, +);
@@ -4512,7 +4541,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_Not:
{
- if (inst->getType()->op == kIROp_BoolType)
+ if (inst->getType()->Equals(getSession()->getBoolType()))
{
emit("!");
}
@@ -4520,73 +4549,72 @@ emitDeclImpl(decl, nullptr);
{
emit("~");
}
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
}
break;
case kIROp_Sample:
- // argument 0 is the instruction's type
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(".Sample(");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit(", ");
- emitIROperand(context, inst->getArg(3));
+ emitIROperand(context, inst->getArg(2));
emit(")");
break;
case kIROp_SampleGrad:
// argument 0 is the instruction's type
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(".SampleGrad(");
+ emitIROperand(context, inst->getArg(1));
+ emit(", ");
emitIROperand(context, inst->getArg(2));
emit(", ");
emitIROperand(context, inst->getArg(3));
emit(", ");
emitIROperand(context, inst->getArg(4));
- emit(", ");
- emitIROperand(context, inst->getArg(5));
emit(")");
break;
case kIROp_Load:
// TODO: this logic will really only work for a simple variable reference...
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
break;
case kIROp_Store:
// TODO: this logic will really only work for a simple variable reference...
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(" = ");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
break;
case kIROp_Call:
{
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("(");
- UInt argCount = inst->getArgCount() - 2;
+ UInt argCount = inst->getArgCount() - 1;
for( UInt aa = 0; aa < argCount; ++aa )
{
if(aa != 0) emit(", ");
- emitIROperand(context, inst->getArg(aa + 2));
+ emitIROperand(context, inst->getArg(aa + 1));
}
emit(")");
}
break;
case kIROp_BufferLoad:
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("[");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit("]");
break;
case kIROp_BufferStore:
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("[");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit("] = ");
- emitIROperand(context, inst->getArg(3));
+ emitIROperand(context, inst->getArg(2));
break;
case kIROp_GroupMemoryBarrierWithGroupSync:
@@ -4595,9 +4623,9 @@ emitDeclImpl(decl, nullptr);
case kIROp_getElement:
case kIROp_getElementPtr:
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit("[");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit("]");
break;
@@ -4605,9 +4633,9 @@ emitDeclImpl(decl, nullptr);
case kIROp_Mul_Matrix_Vector:
case kIROp_Mul_Matrix_Matrix:
emit("mul(");
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(", ");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit(")");
break;
@@ -4619,7 +4647,7 @@ emitDeclImpl(decl, nullptr);
UInt elementCount = ii->getElementCount();
for (UInt ee = 0; ee < elementCount; ++ee)
{
- IRInst* irElementIndex = ii->getElementIndex(ee);
+ IRValue* irElementIndex = ii->getElementIndex(ee);
assert(irElementIndex->op == kIROp_IntLit);
IRConstant* irConst = (IRConstant*)irElementIndex;
@@ -4658,7 +4686,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_Var:
{
auto ptrType = inst->getType();
- auto valType = ((IRPtrType*)ptrType)->getValueType();
+ auto valType = ((PtrType*)ptrType)->getValueType();
auto name = getName(inst);
emitIRType(context, valType, name);
@@ -4689,14 +4717,14 @@ emitDeclImpl(decl, nullptr);
{
auto ii = (IRSwizzleSet*)inst;
emitIRInstResultDecl(context, inst);
- emitIROperand(context, inst->getArg(1));
+ emitIROperand(context, inst->getArg(0));
emit(";\n");
emitIROperand(context, inst);
emit(".");
UInt elementCount = ii->getElementCount();
for (UInt ee = 0; ee < elementCount; ++ee)
{
- IRInst* irElementIndex = ii->getElementIndex(ee);
+ IRValue* irElementIndex = ii->getElementIndex(ee);
assert(irElementIndex->op == kIROp_IntLit);
IRConstant* irConst = (IRConstant*)irElementIndex;
@@ -4707,34 +4735,16 @@ emitDeclImpl(decl, nullptr);
emit(kComponents[elementIndex]);
}
emit(" = ");
- emitIROperand(context, inst->getArg(2));
+ emitIROperand(context, inst->getArg(1));
emit(";\n");
}
break;
-
-#if 0
- case kIROp_unconditionalBranch:
- emit("// unconditionalBranch ");
- emitIROperand(context, inst->getArg(1));
- emit("\n");
- break;
-
- case kIROp_conditionalBranch:
- emit("// conditionalBranch ");
- emitIROperand(context, inst->getArg(1));
- emit(" ");
- emitIROperand(context, inst->getArg(2));
- emit(" ");
- emitIROperand(context, inst->getArg(3));
- emit("\n");
- break;
-#endif
}
}
void emitIRSemantics(
EmitContext* context,
- IRInst* inst)
+ IRValue* inst)
{
auto decoration = inst->findDecoration<IRHighLevelDeclDecoration>();
if( decoration )
@@ -4745,7 +4755,7 @@ emitDeclImpl(decl, nullptr);
VarLayout* getVarLayout(
EmitContext* context,
- IRInst* var)
+ IRValue* var)
{
auto decoration = var->findDecoration<IRLayoutDecoration>();
if (!decoration)
@@ -4756,7 +4766,7 @@ emitDeclImpl(decl, nullptr);
void emitIRLayoutSemantics(
EmitContext* context,
- IRInst* inst,
+ IRValue* inst,
char const* uniformSemanticSpelling = "register")
{
auto layout = getVarLayout(context, inst);
@@ -4784,9 +4794,9 @@ emitDeclImpl(decl, nullptr);
while(block != end)
{
// Start by emitting the non-terminator instructions in the block.
- auto terminator = block->lastChild;
+ auto terminator = block->getLastInst();
assert(isTerminatorInst(terminator));
- for (auto inst = block->firstChild; inst != terminator; inst = inst->nextInst)
+ for (auto inst = block->getFirstInst(); inst != terminator; inst = inst->nextInst)
{
emitIRInst(context, inst);
}
@@ -5095,10 +5105,8 @@ emitDeclImpl(decl, nullptr);
// encoded as a parameter of pointer type, so
// we need to decode that here.
//
- if( paramType->op == kIROp_PtrType )
+ if( auto ptrType = paramType->As<PtrType>() )
{
- auto ptrType = (IRPtrType*) paramType;
-
// TODO: we need a way to distinguish `out`
// from `inout`. The easiest way to do
// that might be to have each be a distinct
@@ -5243,7 +5251,7 @@ emitDeclImpl(decl, nullptr);
}
}
-
+#if 0
void emitIRStruct(
EmitContext* context,
IRStructDecl* structType)
@@ -5263,6 +5271,7 @@ emitDeclImpl(decl, nullptr);
}
emit("};\n");
}
+#endif
void emitIRVarModifiers(
EmitContext* context,
@@ -5317,9 +5326,9 @@ emitDeclImpl(decl, nullptr);
}
void emitIRParameterBlock(
- EmitContext* context,
- IRVar* varDecl,
- IRUniformBufferType* type)
+ EmitContext* context,
+ IRGlobalVar* varDecl,
+ UniformParameterBlockType* type)
{
emit("cbuffer ");
emit(getName(varDecl));
@@ -5341,17 +5350,15 @@ emitDeclImpl(decl, nullptr);
typeLayout = parameterBlockTypeLayout->elementTypeLayout;
}
- switch( elementType->op )
+ if(auto declRefType = elementType->As<DeclRefType>())
{
- case kIROp_StructType:
+ if(auto structDeclRef = declRefType->declRef.As<StructDecl>())
{
- auto structType = (IRStructDecl*) elementType;
-
auto structTypeLayout = typeLayout.As<StructTypeLayout>();
assert(structTypeLayout);
UInt fieldIndex = 0;
- for(auto ff = structType->getFirstField(); ff; ff = ff->getNextField())
+ for(auto ff : GetFields(structDeclRef))
{
// TODO: need a plan to deal with the case where the IR-level
// `struct` type might not match the high-level type, so that
@@ -5367,19 +5374,18 @@ emitDeclImpl(decl, nullptr);
emitIRVarModifiers(context, fieldLayout);
- auto fieldType = ff->getFieldType();
- emitIRType(context, fieldType, getName(ff));
+ auto fieldType = GetType(ff);
+ emitIRType(context, fieldType, ff.GetName());
emitHLSLParameterBlockFieldLayoutSemantics(layout, fieldLayout);
emit(";\n");
}
}
- break;
-
- default:
+ }
+ else
+ {
emit("/* unexpected */");
- break;
}
emit("}\n");
@@ -5391,8 +5397,9 @@ emitDeclImpl(decl, nullptr);
{
auto allocatedType = varDecl->getType();
auto varType = allocatedType->getValueType();
- auto addressSpace = allocatedType->getAddressSpace();
+// auto addressSpace = allocatedType->getAddressSpace();
+#if 0
switch( varType->op )
{
case kIROp_ConstantBufferType:
@@ -5403,6 +5410,7 @@ emitDeclImpl(decl, nullptr);
default:
break;
}
+#endif
// Need to emit appropriate modifiers here.
@@ -5410,6 +5418,7 @@ emitDeclImpl(decl, nullptr);
emitIRVarModifiers(context, layout);
+#if 0
switch (addressSpace)
{
default:
@@ -5419,6 +5428,51 @@ emitDeclImpl(decl, nullptr);
emit("groupshared ");
break;
}
+#endif
+
+ emitIRType(context, varType, getName(varDecl));
+
+ emitIRSemantics(context, varDecl);
+
+ emitIRLayoutSemantics(context, varDecl);
+
+ emit(";\n");
+ }
+
+ void emitIRGlobalVar(
+ EmitContext* context,
+ IRGlobalVar* varDecl)
+ {
+ auto allocatedType = varDecl->getType();
+ auto varType = allocatedType->getValueType();
+// auto addressSpace = allocatedType->getAddressSpace();
+
+ if (auto paramBlockType = varType->As<UniformParameterBlockType>())
+ {
+ emitIRParameterBlock(
+ context,
+ varDecl,
+ paramBlockType);
+ return;
+ }
+
+ // Need to emit appropriate modifiers here.
+
+ auto layout = getVarLayout(context, varDecl);
+
+ emitIRVarModifiers(context, layout);
+
+#if 0
+ switch (addressSpace)
+ {
+ default:
+ break;
+
+ case kIRAddressSpace_GroupShared:
+ emit("groupshared ");
+ break;
+ }
+#endif
emitIRType(context, varType, getName(varDecl));
@@ -5431,7 +5485,7 @@ emitDeclImpl(decl, nullptr);
void emitIRGlobalInst(
EmitContext* context,
- IRInst* inst)
+ IRGlobalValue* inst)
{
// TODO: need to be able to `switch` on the IR opcode here,
// so there is some work to be done.
@@ -5441,9 +5495,15 @@ emitDeclImpl(decl, nullptr);
emitIRFunc(context, (IRFunc*) inst);
break;
+ case kIROp_global_var:
+ emitIRGlobalVar(context, (IRGlobalVar*) inst);
+ break;
+
+#if 0
case kIROp_StructType:
emitIRStruct(context, (IRStructDecl*) inst);
break;
+#endif
case kIROp_Var:
emitIRVar(context, (IRVar*) inst);
@@ -5454,13 +5514,154 @@ emitDeclImpl(decl, nullptr);
}
}
+ void ensureStructDecl(
+ EmitContext* context,
+ DeclRef<StructDecl> declRef)
+ {
+ // TODO: Eventually need to deal with the case where
+ // we have user-defined generic types.
+ //
+ auto decl = declRef.getDecl();
+
+ if(context->shared->irDeclsVisited.Contains(decl))
+ return;
+
+ context->shared->irDeclsVisited.Add(decl);
+
+ // First emit any types used by fields of this type
+ for( auto ff : GetFields(declRef) )
+ {
+ if(ff.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ auto fieldType = GetType(ff);
+ emitIRUsedType(context, fieldType);
+ }
+
+ Emit("struct ");
+ emit(declRef.GetName());
+ Emit("\n{\n");
+ for( auto ff : GetFields(declRef) )
+ {
+ if(ff.getDecl()->HasModifier<HLSLStaticModifier>())
+ continue;
+
+ auto fieldType = GetType(ff);
+ emitIRType(context, fieldType, ff.GetName());
+
+ EmitSemantics(ff.getDecl());
+
+ emit(";\n");
+ }
+ Emit("};\n");
+ }
+
+ // A type is going to be used by the IR, so
+ // make sure that we have emitted whatever
+ // it needs.
+ void emitIRUsedType(
+ EmitContext* context,
+ Type* type)
+ {
+ if(type->As<BasicExpressionType>())
+ {}
+ else if(type->As<VectorExpressionType>())
+ {}
+ else if(type->As<MatrixExpressionType>())
+ {}
+ else if(auto arrayType = type->As<ArrayExpressionType>())
+ {
+ emitIRUsedType(context, arrayType->baseType);
+ }
+ else if( auto textureType = type->As<TextureTypeBase>() )
+ {
+ emitIRUsedType(context, textureType->elementType);
+ }
+ else if( auto genericType = type->As<BuiltinGenericType>() )
+ {
+ emitIRUsedType(context, genericType->elementType);
+ }
+ else if( auto ptrType = type->As<PtrType>() )
+ {
+ emitIRUsedType(context, ptrType->getValueType());
+ }
+ else if(type->As<SamplerStateType>() )
+ {
+ }
+ else if( auto declRefType = type->As<DeclRefType>() )
+ {
+ auto declRef = declRefType->declRef;
+ auto decl = declRef.getDecl();
+
+ if(decl->HasModifier<BuiltinTypeModifier>()
+ || decl->HasModifier<MagicTypeModifier>())
+ {
+ return;
+ }
+
+ if( auto structDeclRef = declRef.As<StructDecl>() )
+ {
+ //
+ ensureStructDecl(context, structDeclRef);
+ }
+ }
+ else
+ {}
+ }
+
+ void emitIRUsedTypesForValue(
+ EmitContext* context,
+ IRValue* value)
+ {
+ if(!value) return;
+ switch( value->op )
+ {
+ case kIROp_Func:
+ {
+ auto irFunc = (IRFunc*) value;
+ emitIRUsedType(context, irFunc->getResultType());
+ for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() )
+ {
+ for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() )
+ {
+ emitIRUsedTypesForValue(context, pp);
+ }
+
+ for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst )
+ {
+ emitIRUsedTypesForValue(context, ii);
+ }
+ }
+ }
+ break;
+
+ default:
+ {
+ emitIRUsedType(context, value->type);
+ }
+ break;
+ }
+ }
+
+ void emitIRUsedTypesForModule(
+ EmitContext* context,
+ IRModule* module)
+ {
+ for (auto gv : module->globalValues)
+ {
+ emitIRUsedTypesForValue(context, gv);
+ }
+ }
+
void emitIRModule(
EmitContext* context,
IRModule* module)
{
- for(auto ii = module->firstChild; ii; ii = ii->nextInst )
+ emitIRUsedTypesForModule(context, module);
+
+ for (auto gv : module->globalValues)
{
- emitIRGlobalInst(context, ii);
+ emitIRGlobalInst(context, gv);
}
}
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index 60dc353ac..c11d66571 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -22,20 +22,20 @@ INST(arrayType, Array, 2, 0)
INST(BoolType, Bool, 0, 0)
-INST(Float16Type, Float16, 0, 0)
-INST(Float32Type, Float32, 0, 0)
-INST(Float64Type, Float64, 0, 0)
+INST(Float16Type, Float16, 0, 0)
+INST(Float32Type, Float32, 0, 0)
+INST(Float64Type, Float64, 0, 0)
// Signed integer types.
// Note that `IntPtr` represents a pointer-sized integer type,
// and will end up being equivalent to either `Int32` or `Int64`
// when it comes time to actually generate code.
//
-INST(Int8Type, Int8, 0, 0)
-INST(Int16Type, Int16, 0, 0)
-INST(Int32Type, Int32, 0, 0)
-INST(IntPtrType, IntPtr, 0, 0)
-INST(Int64Type, Int64, 0, 0)
+INST(Int8Type, Int8, 0, 0)
+INST(Int16Type, Int16, 0, 0)
+INST(Int32Type, Int32, 0, 0)
+INST(IntPtrType, IntPtr, 0, 0)
+INST(Int64Type, Int64, 0, 0)
// Unlike a lot of other IRs, we retain a distinction between
// signed and unsigned integer types, simply because many of
@@ -52,11 +52,11 @@ INST(Int64Type, Int64, 0, 0)
// or else we want to keep using the orignal types, but need
// to cast around any ordinary math operations on signed types.
//
-INST(UInt8Type, Int8, 0, 0)
-INST(UInt16Type, Int16, 0, 0)
-INST(UInt32Type, Int32, 0, 0)
-INST(UIntPtrType, IntPtr, 0, 0)
-INST(UInt64Type, Int64, 0, 0)
+INST(UInt8Type, Int8, 0, 0)
+INST(UInt16Type, Int16, 0, 0)
+INST(UInt32Type, Int32, 0, 0)
+INST(UIntPtrType, IntPtr, 0, 0)
+INST(UInt64Type, Int64, 0, 0)
// A user-defined structure declaration at the IR level.
// Unlike in the AST where there is a distinction between
@@ -94,6 +94,7 @@ INST(GenericParameterType, GenericParameterType, 1, 0)
INST(boolConst, boolConst, 0, 0)
INST(IntLit, integer_constant, 0, 0)
INST(FloatLit, float_constant, 0, 0)
+INST(decl_ref, decl_ref, 0, 0)
INST(Construct, construct, 0, 0)
INST(Call, call, 1, 0)
@@ -102,6 +103,8 @@ INST(Module, module, 0, PARENT)
INST(Func, func, 0, PARENT)
INST(Block, block, 0, PARENT)
+INST(global_var, global_var, 0, 0)
+
INST(Param, param, 0, 0)
INST(StructField, field, 0, 0)
INST(Var, var, 0, 0)
diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h
index 1e9638ade..74151adaa 100644
--- a/source/slang/ir-insts.h
+++ b/source/slang/ir-insts.h
@@ -10,6 +10,7 @@
#include "compiler.h"
#include "ir.h"
+#include "syntax.h"
namespace Slang {
@@ -62,65 +63,22 @@ struct IRLoopControlDecoration : IRDecoration
IRLoopControl mode;
};
-struct IRMangledNameDecoration : IRDecoration
-{
- enum { kDecorationOp = kIRDecorationOp_MangledName };
-
- String mangledName;
-};
-
-// Types
-
-struct IRVectorType : IRType
-{
- IRUse elementType;
- IRUse elementCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getElementCount() { return elementCount.usedValue; }
-};
-
-struct IRMatrixType : IRType
-{
- IRUse elementType;
- IRUse rowCount;
- IRUse columnCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getRowCount() { return rowCount.usedValue; }
- IRInst* getColumnCount() { return columnCount.usedValue; }
-};
-
-struct IRArrayType : IRType
-{
- IRUse elementType;
- IRUse elementCount;
-
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
- IRInst* getElementCount() { return elementCount.usedValue; }
-};
+//
-struct IRGenericParameterType : IRType
+// An IR node to represent a reference to an AST-level
+// declaration.
+struct IRDeclRef : IRValue
{
- IRUse index;
+ DeclRefBase declRef;
};
+//
-struct IRFuncType : IRType
-{
- IRUse resultType;
- // parameter tyeps are varargs...
-
- IRType* getResultType() { return (IRType*) resultType.usedValue; }
- UInt getParamCount()
- {
- return getArgCount() - 2;
- }
- IRType* getParamType(UInt index)
- {
- return (IRType*) getArg(2 + index);
- }
-};
+// TODO(tfoley): the IR-level representation of
+// pointers had an "address space" field that the
+// current AST level lacks. This capability was
+// used to represent `groupshared` allocation,
+// so we probably need to find an alternative.
// Address spaces for IR pointers
enum IRAddressSpace : UInt
@@ -132,6 +90,7 @@ enum IRAddressSpace : UInt
kIRAddressSpace_GroupShared,
};
+#if 0
struct IRPtrType : IRType
{
IRUse valueType;
@@ -145,31 +104,7 @@ struct IRPtrType : IRType
((IRConstant*)addressSpace.usedValue)->u.intVal);
}
};
-
-struct IRTextureType : IRType
-{
- IRUse flavor;
- IRUse elementType;
-
- IRIntegerValue getFlavor() { return ((IRConstant*) flavor.usedValue)->u.intVal; }
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-struct IRBufferType : IRType
-{
- IRUse elementType;
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-
-struct IRUniformBufferType : IRType
-{
- IRUse elementType;
- IRType* getElementType() { return (IRType*) elementType.usedValue; }
-};
-
-struct IRConstantBufferType : IRUniformBufferType {};
-struct IRTextureBufferType : IRUniformBufferType {};
+#endif
struct IRCall : IRInst
{
@@ -187,14 +122,13 @@ struct IRStore : IRInst
IRUse val;
};
-struct IRStructField;
struct IRFieldExtract : IRInst
{
IRUse base;
IRUse field;
- IRInst* getBase() { return base.usedValue; }
- IRStructField* getField() { return (IRStructField*) field.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
+ IRValue* getField() { return field.usedValue; }
};
struct IRFieldAddress : IRInst
@@ -202,8 +136,8 @@ struct IRFieldAddress : IRInst
IRUse base;
IRUse field;
- IRInst* getBase() { return base.usedValue; }
- IRStructField* getField() { return (IRStructField*) field.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
+ IRValue* getField() { return field.usedValue; }
};
// Terminators
@@ -215,7 +149,7 @@ struct IRReturnVal : IRReturn
{
IRUse val;
- IRInst* getVal() { return val.usedValue; }
+ IRValue* getVal() { return val.usedValue; }
};
struct IRReturnVoid : IRReturn
@@ -260,7 +194,7 @@ struct IRConditionalBranch : IRTerminatorInst
IRUse trueBlock;
IRUse falseBlock;
- IRInst* getCondition() { return condition.usedValue; }
+ IRValue* getCondition() { return condition.usedValue; }
IRBlock* getTrueBlock() { return (IRBlock*)trueBlock.usedValue; }
IRBlock* getFalseBlock() { return (IRBlock*)falseBlock.usedValue; }
};
@@ -296,12 +230,12 @@ struct IRSwizzle : IRReturn
{
IRUse base;
- IRInst* getBase() { return base.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
UInt getElementCount()
{
return getArgCount() - 2;
}
- IRInst* getElementIndex(UInt index)
+ IRValue* getElementIndex(UInt index)
{
return getArg(index + 2);
}
@@ -312,43 +246,43 @@ struct IRSwizzleSet : IRReturn
IRUse base;
IRUse source;
- IRInst* getBase() { return base.usedValue; }
- IRInst* getSource() { return source.usedValue; }
+ IRValue* getBase() { return base.usedValue; }
+ IRValue* getSource() { return source.usedValue; }
UInt getElementCount()
{
return getArgCount() - 3;
}
- IRInst* getElementIndex(UInt index)
+ IRValue* getElementIndex(UInt index)
{
return getArg(index + 3);
}
};
-// "Parent" Instructions (Declarations)
-
-struct IRStructField : IRInst
+// An IR `var` instruction conceptually represents
+// a stack allocation of some memory.
+struct IRVar : IRInst
{
- IRType* getFieldType() { return (IRType*) type.usedValue; }
-
- IRStructField* getNextField() { return (IRStructField*) nextInst; }
+ PtrType* getType()
+ {
+ return (PtrType*)type.Ptr();
+ }
};
-struct IRStructDecl : IRParentInst
+struct IRGlobalVar : IRGlobalValue
{
- IRStructField* getFirstField() { return (IRStructField*) firstChild; }
- IRStructField* getLastField() { return (IRStructField*) lastChild; }
-};
+ // TODO: should contain information
+ // for use in initializing the variable
+ // (e.g., a reference to a function
+ // that is to be evaluated to provide
+ // the initial value, or a basic block
+ // that defines a DAG of constant
+ // values to use as initial values...)
-
-struct IRVar : IRInst
-{
- IRPtrType* getType()
- {
- return (IRPtrType*)type.usedValue;
- }
+ PtrType* getType() { return type.As<PtrType>(); }
};
+
// Description of an instruction to be used for global value numbering
struct IRInstKey
{
@@ -369,6 +303,13 @@ bool operator==(IRConstantKey const& left, IRConstantKey const& right);
struct SharedIRBuilder
{
+ // The parent compilation session
+ Session* session;
+ Session* getSession()
+ {
+ return session;
+ }
+
// The module that will own all of the IR
IRModule* module;
@@ -381,53 +322,36 @@ struct IRBuilder
// Shared state for all IR builders working on the same module
SharedIRBuilder* shared;
- IRModule* getModule() { return shared->module; }
-
- // The parent instruction to add children to.
- IRParentInst* parentInst;
-
- void addInst(IRParentInst* parent, IRInst* inst);
- void addInst(IRInst* inst);
-
- IRType* getBaseType(BaseType flavor);
- IRType* getBoolType();
- IRType* getVectorType(IRType* elementType, IRValue* elementCount);
- IRType* getMatrixType(
- IRType* elementType,
- IRValue* rowCount,
- IRValue* columnCount);
- IRType* getArrayType(IRType* elementType, IRValue* elementCount);
- IRType* getArrayType(IRType* elementType);
- IRType* getGenericParameterType(UInt index);
-
- IRType* getTypeType();
- IRType* getVoidType();
- IRType* getBlockType();
-
- IRType* getIntrinsicType(
- IROp op,
- UInt argCount,
- IRValue* const* args);
+ Session* getSession()
+ {
+ return shared->getSession();
+ }
- IRStructDecl* createStructType();
- IRStructField* createStructField(IRType* fieldType);
+ IRModule* getModule() { return shared->module; }
- IRType* getFuncType(
- UInt paramCount,
- IRType* const* paramTypes,
- IRType* resultType);
+ // The current function and block being inserted into
+ // (or `null` if we aren't inserting).
+ IRFunc* func = nullptr;
+ IRBlock* block = nullptr;
+ //
+ // TODO: we eventually also want an `IRInst*` for
+ // an instruction to insert before, so that we
+ // can also use the builder to insert inside
+ // an existing block.
- IRType* getPtrType(
- IRType* valueType,
- IRAddressSpace addressSpace);
+ IRFunc* getFunc() { return func; }
+ IRBlock* getBlock() { return block; }
- IRType* getPtrType(
- IRType* valueType);
+ void addInst(IRBlock* block, IRInst* inst);
+ void addInst(IRInst* inst);
IRValue* getBoolValue(bool value);
IRValue* getIntValue(IRType* type, IRIntegerValue value);
IRValue* getFloatValue(IRType* type, IRFloatingPointValue value);
+ IRValue* getDeclRefVal(
+ DeclRefBase const& declRef);
+
IRInst* emitCallInst(
IRType* type,
IRValue* func,
@@ -448,6 +372,8 @@ struct IRBuilder
IRModule* createModule();
IRFunc* createFunc();
+ IRGlobalVar* createGlobalVar(
+ IRType* valueType);
IRBlock* createBlock();
IRBlock* emitBlock();
@@ -472,12 +398,12 @@ struct IRBuilder
IRInst* emitFieldExtract(
IRType* type,
IRValue* base,
- IRStructField* field);
+ IRValue* field);
IRInst* emitFieldAddress(
IRType* type,
IRValue* basePtr,
- IRStructField* field);
+ IRValue* field);
IRInst* emitElementExtract(
IRType* type,
@@ -556,24 +482,24 @@ struct IRBuilder
IRBlock* breakBlock);
IRDecoration* addDecorationImpl(
- IRInst* inst,
+ IRValue* value,
UInt decorationSize,
IRDecorationOp op);
template<typename T>
- T* addDecoration(IRInst* inst, IRDecorationOp op)
+ T* addDecoration(IRValue* value, IRDecorationOp op)
{
- return (T*) addDecorationImpl(inst, sizeof(T), op);
+ return (T*) addDecorationImpl(value, sizeof(T), op);
}
template<typename T>
- T* addDecoration(IRInst* inst)
+ T* addDecoration(IRValue* value)
{
- return (T*) addDecorationImpl(inst, sizeof(T), IRDecorationOp(T::kDecorationOp));
+ return (T*) addDecorationImpl(value, sizeof(T), IRDecorationOp(T::kDecorationOp));
}
- IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRInst* inst, Decl* decl);
- IRLayoutDecoration* addLayoutDecoration(IRInst* inst, Layout* layout);
+ IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRValue* value, Decl* decl);
+ IRLayoutDecoration* addLayoutDecoration(IRValue* value, Layout* layout);
};
}
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 3e12ef1a2..ee28227e3 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -40,7 +40,7 @@ namespace Slang
//
- void IRUse::init(IRValue* u, IRValue* v)
+ void IRUse::init(IRInst* u, IRValue* v)
{
user = u;
usedValue = v;
@@ -58,10 +58,17 @@ namespace Slang
IRUse* IRInst::getArgs()
{
- return &type;
+ // We assume that *all* instructions are laid out
+ // in memory such that their arguments come right
+ // after the first `sizeof(IRInst)` bytes.
+ //
+ // TODO: we probably need to be careful and make
+ // this more robust.
+
+ return (IRUse*)(this + 1);
}
- IRDecoration* IRInst::findDecorationImpl(IRDecorationOp decorationOp)
+ IRDecoration* IRValue::findDecorationImpl(IRDecorationOp decorationOp)
{
for( auto dd = firstDecoration; dd; dd = dd->next )
{
@@ -71,13 +78,28 @@ namespace Slang
return nullptr;
}
+ // IRBlock
+
+ void IRBlock::addParam(IRParam* param)
+ {
+ if (auto lp = lastParam)
+ {
+ lp->nextParam = param;
+ param->prevParam = lp;
+ }
+ else
+ {
+ firstParam = param;
+ }
+ lastParam = param;
+ }
+
// IRFunc
IRType* IRFunc::getResultType() { return getType()->getResultType(); }
UInt IRFunc::getParamCount() { return getType()->getParamCount(); }
IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); }
-
IRParam* IRFunc::getFirstParam()
{
auto entryBlock = getFirstBlock();
@@ -86,41 +108,20 @@ namespace Slang
return entryBlock->getFirstParam();
}
- // IRBlock
-
- IRParam* IRBlock::getFirstParam()
- {
- auto firstInst = firstChild;
- if(!firstInst) return nullptr;
-
- if(firstInst->op != kIROp_Param)
- return nullptr;
-
- return (IRParam*) firstInst;
- }
-
-
- // IRParam
-
- IRParam* IRParam::getNextParam()
+ void IRFunc::addBlock(IRBlock* block)
{
- // TODO: this is written as a search because we don't
- // currently do the careful thing and emit parameters
- // before any other members of a block.
- //
- // This should change on the emit side, instead.
+ block->parentFunc = this;
- auto next = nextInst;
-
- for (;;)
+ if (auto lb = lastBlock)
{
- if (!next) return nullptr;
-
- if(next->op == kIROp_Param)
- return (IRParam*) next;
-
- next = next->nextInst;
+ lb->nextBlock = block;
+ block->prevBlock = lb;
+ }
+ else
+ {
+ firstBlock = block;
}
+ lastBlock = block;
}
//
@@ -156,27 +157,27 @@ namespace Slang
//
// Add an instruction to a specific parent
- void IRBuilder::addInst(IRParentInst* parent, IRInst* inst)
+ void IRBuilder::addInst(IRBlock* block, IRInst* inst)
{
- inst->parent = parent;
+ inst->parentBlock = block;
- if (!parent->firstChild)
+ if (!block->firstInst)
{
inst->prevInst = nullptr;
inst->nextInst = nullptr;
- parent->firstChild = inst;
- parent->lastChild = inst;
+ block->firstInst = inst;
+ block->lastInst = inst;
}
else
{
- auto prev = parent->lastChild;
+ auto prev = block->lastInst;
inst->prevInst = prev;
inst->nextInst = nullptr;
prev->nextInst = inst;
- parent->lastChild = inst;
+ block->lastInst = inst;
}
}
@@ -184,19 +185,42 @@ namespace Slang
void IRBuilder::addInst(
IRInst* inst)
{
- auto parent = parentInst;
+ auto parent = block;
if (!parent)
return;
addInst(parent, inst);
}
+ static IRValue* createValueImpl(
+ IRBuilder* builder,
+ UInt size,
+ IROp op,
+ IRType* type)
+ {
+ IRValue* value = (IRValue*) malloc(size);
+ memset(value, 0, size);
+ value->op = op;
+ value->type = type;
+ return value;
+ }
+
+ template<typename T>
+ static T* createValue(
+ IRBuilder* builder,
+ IROp op,
+ IRType* type)
+ {
+ return (T*) createValueImpl(builder, sizeof(T), op, type);
+ }
+
+
// Create an IR instruction/value and initialize it.
//
// In this case `argCount` and `args` represnt the
// arguments *after* the type (which is a mandatory
// argument for all instructions).
- static IRValue* createInstImpl(
+ static IRInst* createInstImpl(
IRBuilder* builder,
UInt size,
IROp op,
@@ -206,26 +230,17 @@ namespace Slang
UInt varArgCount = 0,
IRValue* const* varArgs = nullptr)
{
- IRValue* inst = (IRInst*) malloc(size);
+ IRInst* inst = (IRInst*) malloc(size);
memset(inst, 0, size);
auto module = builder->getModule();
- if (!module || (type && type->op == kIROp_VoidType))
- {
- // Can't or shouldn't assign an ID to this op
- }
- else
- {
- inst->id = ++module->idCounter;
- }
- inst->argCount = fixedArgCount + varArgCount + 1;
+ inst->argCount = fixedArgCount + varArgCount;
inst->op = op;
- auto operand = inst->getArgs();
+ inst->type = type;
- operand->init(inst, type);
- operand++;
+ auto operand = inst->getArgs();
for( UInt aa = 0; aa < fixedArgCount; ++aa )
{
@@ -394,7 +409,7 @@ namespace Slang
bool operator==(IRInstKey const& left, IRInstKey const& right)
{
if(left.inst->op != right.inst->op) return false;
- if(left.inst->parent != right.inst->parent) return false;
+ if(left.inst->parentBlock != right.inst->parentBlock) return false;
if(left.inst->argCount != right.inst->argCount) return false;
auto argCount = left.inst->argCount;
@@ -412,7 +427,7 @@ namespace Slang
int IRInstKey::GetHashCode()
{
auto code = Slang::GetHashCode(inst->op);
- code = combineHash(code, Slang::GetHashCode(inst->parent));
+ code = combineHash(code, Slang::GetHashCode(inst->parentBlock));
code = combineHash(code, Slang::GetHashCode(inst->argCount));
auto argCount = inst->argCount;
@@ -424,233 +439,21 @@ namespace Slang
return code;
}
- static IRParentInst* joinParentInstsForInsertion(
- IRParentInst* left,
- IRParentInst* right)
- {
- // Are they the same? Easy.
- if(left == right) return left;
-
- // Have we already failed to find a location? Then bail.
- if(!left) return nullptr;
- if(!right) return nullptr;
-
- // Is one inst a parent of the other? Pick the child.
- for( auto ll = left; ll; ll = ll->parent )
- {
- // Did we find the right node in the parent list of the left?
- if(ll == right) return left;
- }
- for( auto rr = right; rr; rr = rr->parent )
- {
- // Did we find the left node in the parent list of the right?
- if(rr == left) return right;
- }
-
- // Seems like they are unrelated, so we should play it safe
- return nullptr;
- }
-
-
- static IRInst* findOrEmitInstImpl(
- IRBuilder* builder,
- UInt size,
- IROp op,
- IRType* type,
- UInt fixedArgCount,
- IRValue* const* fixedArgs,
- UInt varArgCount = 0,
- IRValue* const* varArgs = nullptr)
- {
- // First, we need to pick a good insertion point
- // for the instruction, which we do by looking
- // at its operands.
- //
-
- IRParentInst* parent = builder->shared->module;
- if( type )
- {
- parent = joinParentInstsForInsertion(parent, type->parent);
- }
- for( UInt aa = 0; aa < fixedArgCount; ++aa )
- {
- auto arg = fixedArgs[aa];
- parent = joinParentInstsForInsertion(parent, arg->parent);
- }
- for( UInt aa = 0; aa < varArgCount; ++aa )
- {
- auto arg = varArgs[aa];
- parent = joinParentInstsForInsertion(parent, arg->parent);
- }
-
- // If we failed to find a good insertion point, then insert locally.
- if( !parent )
- {
- parent = builder->parentInst;
- }
-
- if( parent->op == kIROp_Func )
- {
- // We are trying to insert into a function, and we should really
- // be inserting into its entry block.
- assert(parent->firstChild);
- parent = (IRBlock*) ((IRFunc*) parent)->firstChild;
- }
-
- // We now know where we want to insert, but there might
- // already be an equivalent instruction in that block.
- //
- // We will check for such an instruction in a slightly hacky
- // way: we will construct a temporary instruction and
- // then use it to look up in a cache of instructions.
-
- IRInst* keyInst = createInstImpl(builder, size, op, type, fixedArgCount, fixedArgs, varArgCount, varArgs);
- keyInst->parent = parent;
-
- IRInstKey key;
- key.inst = keyInst;
-
- IRInst* inst = nullptr;
- if( builder->shared->globalValueNumberingMap.TryGetValue(key, inst) )
- {
- // We found a match, so just use that.
-
- free(keyInst);
- return inst;
- }
-
- // No match, so use our "key" instruction for real.
- inst = keyInst;
-
- builder->shared->globalValueNumberingMap.Add(key, inst);
-
- keyInst->parent = nullptr;
- builder->addInst(parent, inst);
-
- return inst;
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- UInt argCount,
- IRValue* const* args)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- argCount,
- args);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- UInt fixedArgCount,
- IRValue* const* fixedArgs,
- UInt varArgCount,
- IRValue* const* varArgs)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T) + varArgCount * sizeof(IRUse),
- op,
- type,
- fixedArgCount,
- fixedArgs,
- varArgCount,
- varArgs);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 0,
- nullptr);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- IRInst* arg)
- {
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 1,
- &arg);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- IRInst* arg1,
- IRInst* arg2)
- {
- IRInst* args[] = { arg1, arg2 };
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 2,
- &args[0]);
- }
-
- template<typename T>
- static T* findOrEmitInst(
- IRBuilder* builder,
- IROp op,
- IRType* type,
- IRInst* arg1,
- IRInst* arg2,
- IRInst* arg3)
- {
- IRInst* args[] = { arg1, arg2, arg3 };
- return (T*) findOrEmitInstImpl(
- builder,
- sizeof(T),
- op,
- type,
- 3,
- &args[0]);
- }
-
//
bool operator==(IRConstantKey const& left, IRConstantKey const& right)
{
- if(left.inst->op != right.inst->op) return false;
- if(left.inst->type.usedValue != right.inst->type.usedValue) return false;
- if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false;
- if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false;
+ if(left.inst->op != right.inst->op) return false;
+ if(left.inst->type != right.inst->type) return false;
+ if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false;
+ if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false;
return true;
}
int IRConstantKey::GetHashCode()
{
auto code = Slang::GetHashCode(inst->op);
- code = combineHash(code, Slang::GetHashCode(inst->type.usedValue));
+ code = combineHash(code, Slang::GetHashCode(inst->type));
code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[0]));
code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[1]));
return code;
@@ -668,22 +471,20 @@ namespace Slang
// at its operands.
//
- IRParentInst* parent = builder->shared->module;
-
IRConstant keyInst;
memset(&keyInst, 0, sizeof(keyInst));
keyInst.op = op;
- keyInst.type.usedValue = type;
+ keyInst.type = type;
memcpy(&keyInst.u, value, valueSize);
IRConstantKey key;
key.inst = &keyInst;
- IRConstant* inst = nullptr;
- if( builder->shared->constantMap.TryGetValue(key, inst) )
+ IRConstant* irValue = nullptr;
+ if( builder->shared->constantMap.TryGetValue(key, irValue) )
{
// We found a match, so just use that.
- return inst;
+ return irValue;
}
// We now know where we want to insert, but there might
@@ -693,221 +494,25 @@ namespace Slang
// way: we will construct a temporary instruction and
// then use it to look up in a cache of instructions.
- inst = createInst<IRConstant>(builder, op, type);
- memcpy(&inst->u, value, valueSize);
+ irValue = createInst<IRConstant>(builder, op, type);
+ memcpy(&irValue->u, value, valueSize);
- key.inst = inst;
- builder->shared->constantMap.Add(key, inst);
+ key.inst = irValue;
+ builder->shared->constantMap.Add(key, irValue);
- builder->addInst(parent, inst);
-
- return inst;
+ return irValue;
}
//
- static IRType* getBaseTypeImpl(IRBuilder* builder, IROp op)
- {
- auto inst = findOrEmitInst<IRType>(
- builder,
- op,
- builder->getTypeType());
- return inst;
- }
-
- IRType* IRBuilder::getBaseType(BaseType flavor)
- {
- switch( flavor )
- {
- case BaseType::Void: return getVoidType();
-
- case BaseType::Bool: return getBaseTypeImpl(this, kIROp_BoolType);
- case BaseType::Float: return getBaseTypeImpl(this, kIROp_Float32Type);
- case BaseType::Int: return getBaseTypeImpl(this, kIROp_Int32Type);
- case BaseType::UInt: return getBaseTypeImpl(this, kIROp_UInt32Type);
-
- default:
- SLANG_UNEXPECTED("unhandled base type");
- return nullptr;
- }
- }
-
- IRType* IRBuilder::getBoolType()
- {
- return getBaseType(BaseType::Bool);
- }
-
- IRType* IRBuilder::getVectorType(IRType* elementType, IRValue* elementCount)
- {
- return findOrEmitInst<IRVectorType>(
- this,
- kIROp_VectorType,
- getTypeType(),
- elementType,
- elementCount);
- }
-
- IRType* IRBuilder::getMatrixType(
- IRType* elementType,
- IRValue* rowCount,
- IRValue* columnCount)
- {
- return findOrEmitInst<IRMatrixType>(
- this,
- kIROp_MatrixType,
- getTypeType(),
- elementType,
- rowCount,
- columnCount);
- }
-
- IRType* IRBuilder::getArrayType(IRType* elementType, IRValue* elementCount)
- {
- // The client requests an unsized array by passing `nullptr` for
- // the `elementCount`.
- //
- // We currently encode an unsized array as an ordinary array with
- // zero elements. TODO: carefully consider this choice.
- if (!elementCount)
- {
- elementCount = getIntValue(
- getBaseType(BaseType::Int),
- 0);
- }
-
- return findOrEmitInst<IRArrayType>(
- this,
- kIROp_arrayType,
- getTypeType(),
- elementType,
- elementCount);
- }
-
- IRType* IRBuilder::getArrayType(IRType* elementType)
- {
- return getArrayType(elementType, nullptr);
- }
-
- IRType* IRBuilder::getGenericParameterType(UInt index)
- {
- auto indexVal = getIntValue(getBaseType(BaseType::Int), index);
-
- return findOrEmitInst<IRGenericParameterType>(
- this,
- kIROp_GenericParameterType,
- getTypeType(),
- indexVal);
-
- }
-
-
- IRType* IRBuilder::getTypeType()
- {
- return findOrEmitInst<IRType>(
- this,
- kIROp_TypeType,
- nullptr);
- }
-
- IRType* IRBuilder::getVoidType()
- {
- return findOrEmitInst<IRType>(
- this,
- kIROp_VoidType,
- getTypeType());
- }
-
- IRType* IRBuilder::getBlockType()
- {
- return findOrEmitInst<IRType>(
- this,
- kIROp_BlockType,
- getTypeType());
- }
-
- IRType* IRBuilder::getIntrinsicType(
- IROp op,
- UInt argCount,
- IRValue* const* args)
- {
- return findOrEmitInst<IRType>(
- this,
- op,
- getTypeType(),
- 0,
- nullptr,
- argCount,
- args);
- }
-
-
- IRStructDecl* IRBuilder::createStructType()
- {
- return createInst<IRStructDecl>(
- this,
- kIROp_StructType,
- getTypeType());
- }
-
- IRStructField* IRBuilder::createStructField(IRType* fieldType)
- {
- return createInst<IRStructField>(
- this,
- kIROp_StructField,
- fieldType);
- }
-
-
- IRType* IRBuilder::getFuncType(
- UInt paramCount,
- IRType* const* paramTypes,
- IRType* resultType)
- {
- // TODO: need to unique things here!
- auto inst = createInstWithTrailingArgs<IRFuncType>(
- this,
- kIROp_FuncType,
- getTypeType(),
- 1,
- (IRValue* const*) &resultType,
- paramCount,
- (IRValue* const*) paramTypes);
- addInst(inst);
- return inst;
- }
-
- IRType* IRBuilder::getPtrType(
- IRType* valueType,
- IRAddressSpace addressSpace)
- {
- auto uintType = getBaseType(BaseType::UInt);
- auto irAddressSpace = getIntValue(uintType, addressSpace);
-
- auto inst = findOrEmitInst<IRPtrType>(
- this,
- kIROp_PtrType,
- getTypeType(),
- valueType,
- irAddressSpace);
- return inst;
- }
-
-
- IRType* IRBuilder::getPtrType(
- IRType* valueType)
- {
- return getPtrType(valueType, kIRAddressSpace_Default);
- }
-
-
IRValue* IRBuilder::getBoolValue(bool inValue)
{
IRIntegerValue value = inValue;
return findOrEmitConstant(
this,
kIROp_boolConst,
- getBoolType(),
+ getSession()->getBoolType(),
sizeof(value),
&value);
}
@@ -932,6 +537,18 @@ namespace Slang
&value);
}
+ IRValue* IRBuilder::getDeclRefVal(
+ DeclRefBase const& declRef)
+ {
+ // TODO: we should cache these...
+ auto irValue = createInst<IRDeclRef>(
+ this,
+ kIROp_decl_ref,
+ nullptr);
+ irValue->declRef = declRef;
+ return irValue;
+ }
+
IRInst* IRBuilder::emitCallInst(
IRType* type,
IRValue* func,
@@ -983,52 +600,69 @@ namespace Slang
IRModule* IRBuilder::createModule()
{
- return createInst<IRModule>(
- this,
- kIROp_Module,
- nullptr);
+ return new IRModule();
}
IRFunc* IRBuilder::createFunc()
{
- return createInst<IRFunc>(
+ return createValue<IRFunc>(
this,
kIROp_Func,
nullptr);
}
+ IRGlobalVar* IRBuilder::createGlobalVar(
+ IRType* valueType)
+ {
+ auto ptrType = getSession()->getPtrType(valueType);
+ return createValue<IRGlobalVar>(
+ this,
+ kIROp_global_var,
+ ptrType);
+ }
+
IRBlock* IRBuilder::createBlock()
{
- return createInst<IRBlock>(
+ return createValue<IRBlock>(
this,
kIROp_Block,
- getBlockType());
+ getSession()->getIRBasicBlockType());
}
IRBlock* IRBuilder::emitBlock()
{
- auto inst = createBlock();
- addInst(inst);
- return inst;
+ auto bb = createBlock();
+
+ auto f = this->func;
+ if (f)
+ {
+ f->addBlock(bb);
+ this->block = bb;
+ }
+ return bb;
}
IRParam* IRBuilder::emitParam(
IRType* type)
{
- auto inst = createInst<IRParam>(
+ auto param = createValue<IRParam>(
this,
kIROp_Param,
type);
- addInst(inst);
- return inst;
+
+ if (auto bb = block)
+ {
+ bb->addParam(param);
+ }
+ return param;
}
IRVar* IRBuilder::emitVar(
IRType* type,
IRAddressSpace addressSpace)
{
- auto allocatedType = getPtrType(type, addressSpace);
+ auto allocatedType = getSession()->getPtrType(type);
auto inst = createInst<IRVar>(
this,
kIROp_Var,
@@ -1047,14 +681,14 @@ namespace Slang
IRInst* IRBuilder::emitLoad(
IRValue* ptr)
{
- auto ptrType = ptr->getType();
- if( ptrType->op != kIROp_PtrType )
+ auto ptrType = ptr->getType()->As<PtrType>();
+ if( !ptrType )
{
// Bad!
return nullptr;
}
- auto valueType = ((IRPtrType*) ptrType)->getValueType();
+ auto valueType = ptrType->getValueType();
auto inst = createInst<IRLoad>(
this,
@@ -1070,11 +704,10 @@ namespace Slang
IRValue* dstPtr,
IRValue* srcVal)
{
- auto type = getVoidType();
auto inst = createInst<IRStore>(
this,
kIROp_Store,
- type,
+ nullptr,
dstPtr,
srcVal);
@@ -1085,7 +718,7 @@ namespace Slang
IRInst* IRBuilder::emitFieldExtract(
IRType* type,
IRValue* base,
- IRStructField* field)
+ IRValue* field)
{
auto inst = createInst<IRFieldExtract>(
this,
@@ -1101,7 +734,7 @@ namespace Slang
IRInst* IRBuilder::emitFieldAddress(
IRType* type,
IRValue* base,
- IRStructField* field)
+ IRValue* field)
{
auto inst = createInst<IRFieldAddress>(
this,
@@ -1170,7 +803,7 @@ namespace Slang
UInt elementCount,
UInt const* elementIndices)
{
- auto intType = getBaseType(BaseType::Int);
+ auto intType = getSession()->getBuiltinType(BaseType::Int);
IRValue* irElementIndices[4];
for (UInt ii = 0; ii < elementCount; ++ii)
@@ -1212,7 +845,7 @@ namespace Slang
UInt elementCount,
UInt const* elementIndices)
{
- auto intType = getBaseType(BaseType::Int);
+ auto intType = getSession()->getBuiltinType(BaseType::Int);
IRValue* irElementIndices[4];
for (UInt ii = 0; ii < elementCount; ++ii)
@@ -1229,7 +862,7 @@ namespace Slang
auto inst = createInst<IRReturnVal>(
this,
kIROp_ReturnVal,
- getVoidType(),
+ nullptr,
val);
addInst(inst);
return inst;
@@ -1240,7 +873,7 @@ namespace Slang
auto inst = createInst<IRReturnVoid>(
this,
kIROp_ReturnVoid,
- getVoidType());
+ nullptr);
addInst(inst);
return inst;
}
@@ -1251,7 +884,7 @@ namespace Slang
auto inst = createInst<IRUnconditionalBranch>(
this,
kIROp_unconditionalBranch,
- getVoidType(),
+ nullptr,
block);
addInst(inst);
return inst;
@@ -1263,7 +896,7 @@ namespace Slang
auto inst = createInst<IRBreak>(
this,
kIROp_break,
- getVoidType(),
+ nullptr,
target);
addInst(inst);
return inst;
@@ -1275,7 +908,7 @@ namespace Slang
auto inst = createInst<IRContinue>(
this,
kIROp_continue,
- getVoidType(),
+ nullptr,
target);
addInst(inst);
return inst;
@@ -1286,13 +919,13 @@ namespace Slang
IRBlock* breakBlock,
IRBlock* continueBlock)
{
- IRInst* args[] = { target, breakBlock, continueBlock };
+ IRValue* args[] = { target, breakBlock, continueBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRLoop>(
this,
kIROp_loop,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1304,13 +937,13 @@ namespace Slang
IRBlock* trueBlock,
IRBlock* falseBlock)
{
- IRInst* args[] = { val, trueBlock, falseBlock };
+ IRValue* args[] = { val, trueBlock, falseBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRConditionalBranch>(
this,
kIROp_conditionalBranch,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1322,13 +955,13 @@ namespace Slang
IRBlock* trueBlock,
IRBlock* afterBlock)
{
- IRInst* args[] = { val, trueBlock, afterBlock };
+ IRValue* args[] = { val, trueBlock, afterBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRIf>(
this,
kIROp_if,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1341,13 +974,13 @@ namespace Slang
IRBlock* falseBlock,
IRBlock* afterBlock)
{
- IRInst* args[] = { val, trueBlock, falseBlock, afterBlock };
+ IRValue* args[] = { val, trueBlock, falseBlock, afterBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRIfElse>(
this,
kIROp_ifElse,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1359,13 +992,13 @@ namespace Slang
IRBlock* bodyBlock,
IRBlock* breakBlock)
{
- IRInst* args[] = { val, bodyBlock, breakBlock };
+ IRValue* args[] = { val, bodyBlock, breakBlock };
UInt argCount = sizeof(args) / sizeof(args[0]);
auto inst = createInst<IRLoopTest>(
this,
kIROp_loopTest,
- getVoidType(),
+ nullptr,
argCount,
args);
addInst(inst);
@@ -1373,7 +1006,7 @@ namespace Slang
}
IRDecoration* IRBuilder::addDecorationImpl(
- IRInst* inst,
+ IRValue* inst,
UInt decorationSize,
IRDecorationOp op)
{
@@ -1388,14 +1021,14 @@ namespace Slang
return decoration;
}
- IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl)
+ IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRValue* inst, Decl* decl)
{
auto decoration = addDecoration<IRHighLevelDeclDecoration>(inst, kIRDecorationOp_HighLevelDecl);
decoration->decl = decl;
return decoration;
}
- IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRInst* inst, Layout* layout)
+ IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRValue* inst, Layout* layout)
{
auto decoration = addDecoration<IRLayoutDecoration>(inst);
decoration->layout = layout;
@@ -1409,6 +1042,9 @@ namespace Slang
{
StringBuilder* builder;
int indent;
+
+ UInt idCounter = 1;
+ Dictionary<IRValue*, UInt> mapValueToID;
};
static void dump(
@@ -1454,27 +1090,59 @@ namespace Slang
}
}
+ bool opHasResult(IRValue* inst);
+
+ static UInt getID(
+ IRDumpContext* context,
+ IRValue* value)
+ {
+ UInt id = 0;
+ if (context->mapValueToID.TryGetValue(value, id))
+ return id;
+
+ if (opHasResult(value))
+ {
+ id = context->idCounter++;
+ }
+
+ context->mapValueToID.Add(value, id);
+ return id;
+ }
+
static void dumpID(
IRDumpContext* context,
- IRInst* inst)
+ IRValue* inst)
{
if (!inst)
{
dump(context, "<null>");
+ return;
}
- else if( auto mangled = inst->findDecoration<IRMangledNameDecoration>() )
- {
- dump(context, "@");
- dump(context, mangled->mangledName.Buffer());
- }
- else if(inst->id)
- {
- dump(context, "%");
- dump(context, (UInt) inst->id);
- }
- else
+
+ switch(inst->op)
{
- dump(context, "_");
+ case kIROp_Func:
+ {
+ auto irFunc = (IRFunc*) inst;
+ dump(context, "@");
+ dump(context, irFunc->mangledName.Buffer());
+ }
+ break;
+
+ default:
+ {
+ UInt id = getID(context, inst);
+ if (id)
+ {
+ dump(context, "%");
+ dump(context, id);
+ }
+ else
+ {
+ dump(context, "_");
+ }
+ }
+ break;
}
}
@@ -1484,8 +1152,9 @@ namespace Slang
static void dumpOperand(
IRDumpContext* context,
- IRInst* inst)
+ IRValue* inst)
{
+ // TODO: we should have a dedicated value for the `undef` case
if (!inst)
{
dump(context, "undef");
@@ -1514,22 +1183,90 @@ namespace Slang
break;
}
- auto type = inst->getType();
- if (type)
+ dumpID(context, inst);
+ }
+
+ static void dump(
+ IRDumpContext* context,
+ Name* name)
+ {
+ dump(context, getText(name).Buffer());
+ }
+
+
+ static void dumpDeclRef(
+ IRDumpContext* context,
+ DeclRef<Decl> const& declRef);
+
+ static void dumpVal(
+ IRDumpContext* context,
+ Val* val)
+ {
+ if(auto type = dynamic_cast<Type*>(val))
{
- switch (type->op)
- {
- case kIROp_TypeType:
- dumpType(context, (IRType*)inst);
- return;
+ dumpType(context, type);
+ }
+ else if(auto constIntVal = dynamic_cast<ConstantIntVal*>(val))
+ {
+ dump(context, constIntVal->value);
+ }
+ else if(auto genericParamVal = dynamic_cast<GenericParamIntVal*>(val))
+ {
+ dumpDeclRef(context, genericParamVal->declRef);
+ }
+ else
+ {
+ dump(context, "???");
+ }
+ }
- default:
- break;
- }
+ static void dumpDeclRef(
+ IRDumpContext* context,
+ DeclRef<Decl> const& declRef)
+ {
+ auto decl = declRef.getDecl();
+
+ auto parentDeclRef = declRef.GetParent();
+ auto genericParentDeclRef = parentDeclRef.As<GenericDecl>();
+ if(genericParentDeclRef)
+ {
+ parentDeclRef = genericParentDeclRef.GetParent();
}
+ if(parentDeclRef.As<ModuleDecl>())
+ {
+ parentDeclRef = DeclRef<ContainerDecl>();
+ }
- dumpID(context, inst);
+ if(parentDeclRef)
+ {
+ dumpDeclRef(context, parentDeclRef);
+ dump(context, ".");
+ }
+ dump(context, decl->getName());
+
+ if(genericParentDeclRef)
+ {
+ auto subst = declRef.substitutions;
+ if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() )
+ {
+ // No actual substitutions in place here
+ dump(context, "<>");
+ }
+ else
+ {
+ auto args = subst->args;
+ bool first = true;
+ dump(context, "<");
+ for(auto aa : args)
+ {
+ if(!first) dump(context, ",");
+ dumpVal(context, aa);
+ first = false;
+ }
+ dump(context, ">");
+ }
+ }
}
static void dumpType(
@@ -1538,10 +1275,43 @@ namespace Slang
{
if (!type)
{
- dumpID(context, type);
+ dump(context, "_");
return;
}
+ if(auto funcType = type->As<FuncType>())
+ {
+ UInt paramCount = funcType->getParamCount();
+ dump(context, "(");
+ for( UInt pp = 0; pp < paramCount; ++pp )
+ {
+ if(pp != 0) dump(context, ", ");
+ dumpType(context, funcType->getParamType(pp));
+ }
+ dump(context, ") -> ");
+ dumpType(context, funcType->getResultType());
+ }
+ else if(auto arrayType = type->As<ArrayExpressionType>())
+ {
+ dumpType(context, arrayType->baseType);
+ dump(context, "[");
+ if(auto elementCount = arrayType->ArrayLength)
+ {
+ dumpVal(context, elementCount);
+ }
+ dump(context, "]");
+ }
+ else if(auto declRefType = type->As<DeclRefType>())
+ {
+ dumpDeclRef(context, declRefType->declRef);
+ }
+ else
+ {
+ // Need a default case here
+ dump(context, "???");
+ }
+
+#if 0
auto op = type->op;
auto opInfo = kIROpInfos[op];
@@ -1551,21 +1321,6 @@ namespace Slang
dumpID(context, type);
break;
- case kIROp_FuncType:
- {
- auto funcType = (IRFuncType*) type;
- UInt paramCount = funcType->getParamCount();
- dump(context, "(");
- for( UInt pp = 0; pp < paramCount; ++pp )
- {
- if(pp != 0) dump(context, ", ");
- dumpType(context, funcType->getParamType(pp));
- }
- dump(context, ") -> ");
- dumpType(context, funcType->getResultType());
- }
- break;
-
default:
{
dump(context, opInfo.name);
@@ -1585,6 +1340,7 @@ namespace Slang
}
break;
}
+#endif
}
static void dumpInstTypeClause(
@@ -1602,32 +1358,68 @@ namespace Slang
static void dumpChildrenRaw(
IRDumpContext* context,
- IRParentInst* parent)
+ IRBlock* block)
{
- for (auto ii = parent->firstChild; ii; ii = ii->nextInst)
+ for (auto ii = block->firstInst; ii; ii = ii->nextInst)
{
dumpInst(context, ii);
}
}
- static void dumpChildren(
+ static void dumpBlock(
IRDumpContext* context,
- IRInst* inst)
+ IRBlock* block)
{
- auto op = inst->op;
- auto opInfo = &kIROpInfos[op];
- if (opInfo->flags & kIROpFlag_Parent)
+ context->indent--;
+ dump(context, "block ");
+ dumpID(context, block);
+
+ if( block->getFirstParam() )
{
- dumpIndent(context);
- dump(context, "{\n");
- context->indent++;
- auto parent = (IRParentInst*)inst;
- dumpChildrenRaw(context, parent);
- context->indent--;
- dumpIndent(context);
- dump(context, "}\n");
+ dump(context, "(\n");
+ context->indent += 2;
+ for (auto pp = block->getFirstParam(); pp; pp = pp->getNextParam())
+ {
+ if (pp != block->getFirstParam())
+ dump(context, ",\n");
+
+ dumpIndent(context);
+ dump(context, "param ");
+ dumpID(context, pp);
+ dumpInstTypeClause(context, pp->getType());
+ }
+ context->indent -= 2;
+ dump(context, ")");
+ }
+ dump(context, ":\n");
+ context->indent++;
+
+ dumpChildrenRaw(context, block);
+ }
+
+ static void dumpChildrenRaw(
+ IRDumpContext* context,
+ IRFunc* func)
+ {
+ for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock())
+ {
+ dumpBlock(context, bb);
}
}
+
+ static void dumpChildren(
+ IRDumpContext* context,
+ IRFunc* func)
+ {
+ dumpIndent(context);
+ dump(context, "{\n");
+ context->indent++;
+ dumpChildrenRaw(context, func);
+ context->indent--;
+ dumpIndent(context);
+ dump(context, "}\n");
+ }
+
static void dumpInst(
IRDumpContext* context,
IRInst* inst)
@@ -1646,6 +1438,7 @@ namespace Slang
//
switch (op)
{
+#if 0
case kIROp_Module:
dumpIndent(context);
dump(context, "module\n");
@@ -1717,11 +1510,13 @@ namespace Slang
dumpChildrenRaw(context, block);
}
return;
+#endif
default:
break;
}
+#if 0
// We also want to special-case based on the *type*
// of the instruction
auto type = inst->getType();
@@ -1738,38 +1533,42 @@ namespace Slang
return;
}
}
+#endif
// Okay, we have a seemingly "ordinary" op now
dumpIndent(context);
auto opInfo = &kIROpInfos[op];
+ auto type = inst->getType();
- if (type && type->op == kIROp_TypeType)
- {
- dump(context, "type ");
- dumpID(context, inst);
- dump(context, "\t= ");
- }
- else if (type && type->op == kIROp_VoidType)
+ if (!type)
{
+ // No result, okay...
}
else
{
- dump(context, "let ");
- dumpID(context, inst);
- dumpInstTypeClause(context, type);
- dump(context, "\t= ");
+ auto basicType = type->As<BasicExpressionType>();
+ if (basicType && basicType->baseType == BaseType::Void)
+ {
+ // No result, okay...
+ }
+ else
+ {
+ dump(context, "let ");
+ dumpID(context, inst);
+ dumpInstTypeClause(context, type);
+ dump(context, "\t= ");
+ }
}
-
dump(context, opInfo->name);
uint32_t argCount = inst->argCount;
dump(context, "(");
- for (uint32_t ii = 1; ii < argCount; ++ii)
+ for (uint32_t ii = 0; ii < argCount; ++ii)
{
- if (ii != 1)
+ if (ii != 0)
dump(context, ", ");
auto argVal = inst->getArgs()[ii].usedValue;
@@ -1779,10 +1578,78 @@ namespace Slang
dump(context, ")");
dump(context, "\n");
+ }
+
+ void dumpIRFunc(
+ IRDumpContext* context,
+ IRFunc* func)
+ {
+ dump(context, "\n");
+ dumpIndent(context);
+ dump(context, "ir_func ");
+ dumpID(context, func);
+ dumpInstTypeClause(context, func->getType());
+ dump(context, "\n");
+
+ dumpIndent(context);
+ dump(context, "{\n");
+ context->indent++;
- // The instruction might have children,
- // so we need to handle those here
- dumpChildren(context, inst);
+ for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock())
+ {
+ if (bb != func->getFirstBlock())
+ dump(context, "\n");
+ dumpBlock(context, bb);
+ }
+
+ context->indent--;
+ dump(context, "}\n");
+ }
+
+ void dumpIRGlobalVar(
+ IRDumpContext* context,
+ IRGlobalVar* var)
+ {
+ dump(context, "\n");
+ dumpIndent(context);
+ dump(context, "ir_global_var ");
+ dumpID(context, var);
+ dumpInstTypeClause(context, var->getType());
+
+ // TODO: deal with the case where a global
+ // might have embedded initialization logic.
+
+ dump(context, ";\n");
+ }
+
+ void dumpIRGlobalValue(
+ IRDumpContext* context,
+ IRGlobalValue* value)
+ {
+ switch (value->op)
+ {
+ case kIROp_Func:
+ dumpIRFunc(context, (IRFunc*)value);
+ break;
+
+ case kIROp_global_var:
+ dumpIRGlobalVar(context, (IRGlobalVar*)value);
+ break;
+
+ default:
+ dump(context, "???\n");
+ break;
+ }
+ }
+
+ void dumpIRModule(
+ IRDumpContext* context,
+ IRModule* module)
+ {
+ for (auto gv : module->globalValues)
+ {
+ dumpIRGlobalValue(context, gv);
+ }
}
void printSlangIRAssembly(StringBuilder& builder, IRModule* module)
@@ -1791,7 +1658,7 @@ namespace Slang
context.builder = &builder;
context.indent = 0;
- dumpChildrenRaw(&context, module);
+ dumpIRModule(&context, module);
}
String getSlangIRAssembly(IRModule* module)
diff --git a/source/slang/ir.h b/source/slang/ir.h
index 9c3124478..7e4dc8f83 100644
--- a/source/slang/ir.h
+++ b/source/slang/ir.h
@@ -11,36 +11,14 @@
namespace Slang {
-// TODO(tfoley): We should ditch this enumeration
-// and just use the IR opcodes that represent these
-// types directly. The one major complication there
-// is that the order of the enum values currently
-// matters, since it determines promotion rank.
-// We either need to keep that restriction, or
-// look up promotion rank by some other means.
-//
-enum class BaseType
-{
- // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this
-
- Void = 0,
- Bool,
- Int,
- UInt,
- UInt64,
- Half,
- Float,
- Double,
-};
-
+class FuncType;
+class Layout;
+class Type;
-class Layout;
-
-struct IRFunc;
-struct IRInst;
-struct IRModule;
-struct IRParentInst;
-struct IRType;
+struct IRFunc;
+struct IRInst;
+struct IRModule;
+struct IRValue;
typedef unsigned int IROpFlags;
enum : IROpFlags
@@ -75,34 +53,6 @@ enum IROp : int16_t
};
-#if 0
-enum IRPseudoOp
-{
- kIRPseudoOp_Pos = -1000,
- kIRPseudoOp_PreInc,
- kIRPseudoOp_PreDec,
- kIRPseudoOp_PostInc,
- kIRPseudoOp_PostDec,
- kIRPseudoOp_Sequence,
- kIRPseudoOp_AddAssign,
- kIRPseudoOp_SubAssign,
- kIRPseudoOp_MulAssign,
- kIRPseudoOp_DivAssign,
- kIRPseudoOp_ModAssign,
- kIRPseudoOp_AndAssign,
- kIRPseudoOp_OrAssign,
- kIRPseudoOp_XorAssign ,
- kIRPseudoOp_LshAssign,
- kIRPseudoOp_RshAssign,
- kIRPseudoOp_Assign,
- kIRPseudoOp_BitNot,
- kIRPseudoOp_And,
- kIRPseudoOp_Or,
-
- kIROp_Invalid = -1,
-};
-#endif
-
IROp findIROp(char const* name);
// A logical operation/opcode in the IR
@@ -125,12 +75,12 @@ IROpInfo getIROpInfo(IROp op);
// A use of another value/inst within an IR operation
struct IRUse
{
+ // The value that is being used
+ IRValue* usedValue;
+
// The value that is doing the using.
IRInst* user;
- // The value that is being used
- IRInst* usedValue;
-
// The next use of the same value
IRUse* nextUse;
@@ -138,7 +88,7 @@ struct IRUse
// so that we can simplify updates.
IRUse** prevLink;
- void init(IRInst* user, IRInst* usedValue);
+ void init(IRInst* user, IRValue* usedValue);
};
enum IRDecorationOp : uint16_t
@@ -162,52 +112,61 @@ struct IRDecoration
IRDecorationOp op;
};
-typedef uint32_t IRInstID;
+// Use AST-level types directly to represent the
+// types of IR instructions/values
+typedef Type IRType;
+
+struct IRBlock;
-// In the IR, almost *everything* is an instruction,
-// in order to make the representation as uniform as possible.
-struct IRInst
+// Base class for values in the IR
+struct IRValue
{
// The operation that this value represents
IROp op;
- // A unique ID to represent the op when printing
- // (or zero to indicate that the value of this
- // op isn't special).
- IRInstID id;
+ // The type of the result value of this instruction,
+ // or `null` to indicate that the instruction has
+ // no value.
+ RefPtr<Type> type;
- // The total number of arguments of this instruction
- // (including the type)
- uint32_t argCount;
-
- // The parent of this instruction.
- // This will often be a basic block, but we
- // allow instructions to nest in more general ways.
- IRParentInst* parent;
-
- // The next and previous instructions in the same parent block
- IRInst* nextInst;
- IRInst* prevInst;
+ Type* getType() { return type; }
- // The first use of this value (start of a linked list)
- IRUse* firstUse;
-
- // The linked list of decorations attached to this instruction
+ // The linked list of decorations attached to this value
IRDecoration* firstDecoration;
+ // Look up a decoration in the list of decorations
IRDecoration* findDecorationImpl(IRDecorationOp op);
-
template<typename T>
T* findDecoration()
{
return (T*) findDecorationImpl(IRDecorationOp(T::kDecorationOp));
}
+ // The first use of this value (start of a linked list)
+ IRUse* firstUse;
- // The type of this value
- IRUse type;
- IRType* getType() { return (IRType*) type.usedValue; }
+};
+
+// Instructions are values that can be executed,
+// and which take other values as operands
+struct IRInst : IRValue
+{
+ // The total number of arguments of this instruction.
+ //
+ // TODO: We shouldn't need to allocate this on
+ // all instructions. Instead we should have
+ // instructions that need "vararg" support to
+ // allocate this field ahead of the `this`
+ // pointer.
+ uint32_t argCount;
+
+ // The basic block that contains this instruction.
+ IRBlock* parentBlock;
+
+ // The next and previous instructions in the same parent block
+ IRInst* nextInst;
+ IRInst* prevInst;
UInt getArgCount()
{
@@ -216,20 +175,16 @@ struct IRInst
IRUse* getArgs();
- IRInst* getArg(UInt index)
+ IRValue* getArg(UInt index)
{
return getArgs()[index].usedValue;
}
};
-// This type alias exists because I waffled on the name for a bit.
-// All existing uses of `IRValue` should move to `IRInst`
-typedef IRInst IRValue;
-
typedef int64_t IRIntegerValue;
typedef double IRFloatingPointValue;
-struct IRConstant : IRInst
+struct IRConstant : IRValue
{
union
{
@@ -241,16 +196,6 @@ struct IRConstant : IRInst
} u;
};
-// Representation of a type at the IR level.
-// Such a type may not correspond to the high-level-language notion
-// of a type as used by the front end.
-//
-// Note that types are instructions in the IR, so that operations
-// may take type operands as easily as values.
-struct IRType : IRInst
-{
-};
-
// A instruction that ends a basic block (usually because of control flow)
struct IRTerminatorInst : IRInst
{};
@@ -258,23 +203,20 @@ struct IRTerminatorInst : IRInst
bool isTerminatorInst(IROp op);
bool isTerminatorInst(IRInst* inst);
-
-// A parent instruction contains a sequence of other instructions
+// A function parameter is owned by a basic block, and represents
+// either an incoming function parameter (in the entry block), or
+// a value that flows from one SSA block to another (in a non-entry
+// block).
//
-struct IRParentInst : IRInst
+// In each case, the basic idea is that a block is a "label with
+// arguments."
+struct IRParam : IRValue
{
- // The first and last instruction in the container (or NULL in
- // the case that the container is empty).
- //
- IRInst* firstChild;
- IRInst* lastChild;
-};
+ IRParam* nextParam;
+ IRParam* prevParam;
-// A function parameter is represented by an instruction
-// in the entry block of a function.
-struct IRParam : IRInst
-{
- IRParam* getNextParam();
+ IRParam* getNextParam() { return nextParam; }
+ IRParam* getPrevParam() { return prevParam; }
};
// A basic block is a parent instruction that adds the constraint
@@ -282,48 +224,97 @@ struct IRParam : IRInst
// no function declarations, or nested blocks). We also expect
// that the previous/next instruction are always a basic block.
//
-struct IRBlock : IRParentInst
+struct IRBlock : IRValue
{
+ // Linked list of the instructions contained in this block
+ //
// Note that in a valid program, every block must end with
// a "terminator" instruction, so these should be non-NULL,
- // and `last` should actually be an `IRTerminatorInst`.
+ // and `lastInst` should actually be an `IRTerminatorInst`.
+ IRInst* firstInst;
+ IRInst* lastInst;
- IRBlock* getPrevBlock() { return (IRBlock*) prevInst; }
- IRBlock* getNextBlock() { return (IRBlock*) nextInst; }
+ IRInst* getFirstInst() { return firstInst; }
+ IRInst* getLastInst() { return lastInst; }
- IRFunc* getParent() { return (IRFunc*)parent; }
+ // Links for the list of basic blocks in the parent function
+ IRBlock* prevBlock;
+ IRBlock* nextBlock;
+
+ IRBlock* getPrevBlock() { return prevBlock; }
+ IRBlock* getNextBlock() { return nextBlock; }
+
+ // Linked list of parameters of this block
+ IRParam* firstParam;
+ IRParam* lastParam;
+
+ IRParam* getFirstParam() { return firstParam; }
+ IRParam* getLastParam() { return lastParam; }
+ void addParam(IRParam* param);
+
+ // The parent function that contains this block
+ IRFunc* parentFunc;
+
+ IRFunc* getParent() { return parentFunc; }
- IRParam* getFirstParam();
};
-struct IRFuncType;
+// For right now, we will represent the type of
+// an IR function using the type of the AST
+// function from which it was created.
+//
+// TODO: need to do this better.
+typedef FuncType IRFuncType;
+
+struct IRGlobalValue : IRValue
+{};
// A function is a parent to zero or more blocks of instructions.
//
// A function is itself a value, so that it can be a direct operand of
// an instruction (e.g., a call).
-struct IRFunc : IRParentInst
+struct IRFunc : IRGlobalValue
{
- IRFuncType* getType() { return (IRFuncType*) type.usedValue; }
+ // The type of the IR-level function
+ IRFuncType* getType() { return (IRFuncType*) type.Ptr(); }
+
+ // The mangled name, for a function
+ // that should have linkage.
+ String mangledName;
- IRType* getResultType();
+ // Convenience accessors for working with the
+ // function's type.
+ Type* getResultType();
UInt getParamCount();
- IRType* getParamType(UInt index);
+ Type* getParamType(UInt index);
- IRBlock* getFirstBlock() { return (IRBlock*) firstChild; }
- IRBlock* getLastBlock() { return (IRBlock*) lastChild; }
+ // The list of basic blocks in this function
+ IRBlock* firstBlock = nullptr;
+ IRBlock* lastBlock = nullptr;
+ IRBlock* getFirstBlock() { return firstBlock; }
+ IRBlock* getLastBlock() { return lastBlock; }
+
+ // Add a block to the end of this function.
+ void addBlock(IRBlock* block);
+
+ // Convenience accessor for the IR parameters,
+ // which are actually the parameters of the first
+ // block.
IRParam* getFirstParam();
};
// A module is a parent to functions, global variables, types, etc.
-struct IRModule : IRParentInst
+struct IRModule : RefObject
{
// The designated entry-point function, if any
IRFunc* entryPoint;
- // A special counter used to assign logical ids to instructions in this module.
- IRInstID idCounter;
+ // A list of all the functions and other
+ // global values declared in this module.
+ List<IRGlobalValue*> globalValues;
+
+ // TODO: need a symbol of all the global variables too
};
void printSlangIRAssembly(StringBuilder& builder, IRModule* module);
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 06ad66bc4..7fce0c385 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -79,7 +79,7 @@ struct SubscriptInfo : ExtendedValueInfo
struct BoundSubscriptInfo : ExtendedValueInfo
{
DeclRef<SubscriptDecl> declRef;
- IRType* type;
+ RefPtr<Type> type;
List<IRValue*> args;
UInt genericArgCount;
};
@@ -218,8 +218,8 @@ struct BoundMemberInfo : ExtendedValueInfo
//
struct SwizzledLValueInfo : ExtendedValueInfo
{
- // IR-level The type of the expression.
- IRType* type;
+ // The type of the expression.
+ RefPtr<Type> type;
// The base expression (this should be an l-value)
LoweredValInfo base;
@@ -355,7 +355,7 @@ LoweredValInfo emitCompoundAssignOp(
auto leftVal = builder->emitLoad(leftPtr);
- IRInst* innerArgs[] = { leftVal, rightVal };
+ IRValue* innerArgs[] = { leftVal, rightVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(leftPtr, innerOp);
@@ -363,23 +363,31 @@ LoweredValInfo emitCompoundAssignOp(
return LoweredValInfo::ptr(leftPtr);
}
-IRInst* getOneValOfType(
+IRValue* getOneValOfType(
IRGenContext* context,
IRType* type)
{
- switch(type->op)
+ if (auto basicType = dynamic_cast<BasicExpressionType*>(type))
{
- case kIROp_Int32Type:
- case kIROp_UInt32Type:
- return context->irBuilder->getIntValue(type, 1);
+ switch (basicType->baseType)
+ {
+ case BaseType::Int:
+ case BaseType::UInt:
+ case BaseType::UInt64:
+ return context->irBuilder->getIntValue(type, 1);
- case kIROp_Float32Type:
- return context->irBuilder->getFloatValue(type, 1.0);
+ case BaseType::Float:
+ case BaseType::Double:
+ return context->irBuilder->getFloatValue(type, 1.0);
- default:
- SLANG_UNEXPECTED("inc/dec type");
- return nullptr;
+ default:
+ break;
+ }
}
+ // TODO: should make sure to handle vector and matrix types here
+
+ SLANG_UNEXPECTED("inc/dec type");
+ return nullptr;
}
LoweredValInfo emitPreOp(
@@ -396,9 +404,9 @@ LoweredValInfo emitPreOp(
auto preVal = builder->emitLoad(argPtr);
- IRInst* oneVal = getOneValOfType(context, type);
+ IRValue* oneVal = getOneValOfType(context, type);
- IRInst* innerArgs[] = { preVal, oneVal };
+ IRValue* innerArgs[] = { preVal, oneVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(argPtr, innerOp);
@@ -420,9 +428,9 @@ LoweredValInfo emitPostOp(
auto preVal = builder->emitLoad(argPtr);
- IRInst* oneVal = getOneValOfType(context, type);
+ IRValue* oneVal = getOneValOfType(context, type);
- IRInst* innerArgs[] = { preVal, oneVal };
+ IRValue* innerArgs[] = { preVal, oneVal };
auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs);
builder->emitStore(argPtr, innerOp);
@@ -647,10 +655,7 @@ struct LoweredTypeInfo
Simple,
};
- union
- {
- IRType* type;
- };
+ RefPtr<IRType> type;
Flavor flavor;
LoweredTypeInfo()
@@ -743,6 +748,35 @@ LoweredValInfo lowerDecl(
DeclBase* decl,
Layout* layout);
+IRType* getIntType(
+ IRGenContext* context)
+{
+ return context->getSession()->getBuiltinType(BaseType::Int);
+}
+
+// Get a pointer type to the given element type
+RefPtr<PtrType> getPtrType(
+ IRGenContext* context,
+ IRType* valueType)
+{
+ return context->getSession()->getPtrType(valueType);
+}
+
+RefPtr<IRFuncType> getFuncType(
+ IRGenContext* context,
+ UInt paramCount,
+ RefPtr<IRType> const* paramTypes,
+ IRType* resultType)
+{
+ RefPtr<FuncType> funcType = new FuncType();
+ funcType->resultType = resultType;
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ funcType->paramTypes.Add(paramTypes[pp]);
+ }
+ return funcType;
+}
+
//
struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, LoweredTypeInfo>
@@ -761,7 +795,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
// TODO: it is a bit messy here that the `ConstantIntVal` representation
// has no notion of a *type* associated with the value...
- auto type = getBuilder()->getBaseType(BaseType::Int);
+ auto type = getIntType(context);
return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value));
}
@@ -772,16 +806,7 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredTypeInfo visitFuncType(FuncType* type)
{
- LoweredValInfo loweredFunc = ensureDecl(context, type->declRef);
- auto loweredFuncVal = getSimpleVal(context, loweredFunc);
-
- // HACK: deal with the case where the decl might not
- // lower to anything, and so we don't have a type to
- // work with.
- if (!loweredFuncVal)
- return LoweredTypeInfo();
-
- return loweredFuncVal->getType();
+ return LoweredTypeInfo(type);
}
void addGenericArgs(List<IRValue*>* ioArgs, DeclRefBase declRef)
@@ -799,12 +824,16 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
LoweredTypeInfo visitDeclRefType(DeclRefType* type)
{
+#if 1
+ // TODO: is there actually anything to be done at this point?
+ return LoweredTypeInfo(type);
+#else
// We need to detect builtin/intrinsic types here, since they should map to custom modifiers
// We need to catch builtin/intrinsic types here
if( auto intrinsicTypeMod = type->declRef.getDecl()->FindModifier<IntrinsicTypeModifier>() )
{
auto builder = getBuilder();
- auto intType = builder->getBaseType(BaseType::Int);
+ auto intType = getIntType(context);
//
List<IRValue*> irArgs;
for( auto val : intrinsicTypeMod->irOperands )
@@ -831,61 +860,32 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
default:
SLANG_UNIMPLEMENTED_X("type lowering");
}
-
+#endif
}
LoweredTypeInfo visitBasicExpressionType(BasicExpressionType* type)
{
- return getBuilder()->getBaseType(type->baseType);
+ return LoweredTypeInfo(type);
}
LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type)
{
- auto irElementType = lowerSimpleType(context, type->elementType);
- auto irElementCount = lowerSimpleVal(context, type->elementCount);
-
- return getBuilder()->getVectorType(irElementType, irElementCount);
+ return LoweredTypeInfo(type);
}
LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type)
{
- auto irElementType = lowerSimpleType(context, type->getElementType());
- auto irRowCount = lowerSimpleVal(context, type->getRowCount());
- auto irColumnCount = lowerSimpleVal(context, type->getColumnCount());
-
- return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount);
+ return LoweredTypeInfo(type);
}
- LoweredTypeInfo getArrayType(
- LoweredTypeInfo const& loweredElementType,
- IRValue* irElementCount)
+ LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type)
{
- switch (loweredElementType.flavor)
- {
- case LoweredTypeInfo::Flavor::Simple:
- return getBuilder()->getArrayType(
- loweredElementType.type,
- irElementCount);
- break;
-
- default:
- SLANG_UNEXPECTED("array element type");
- break;
- }
+ return LoweredTypeInfo(type);
}
- LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type)
+ LoweredTypeInfo visitIRBasicBlockType(IRBasicBlockType* type)
{
- auto loweredElementType = lowerType(context, type->baseType);
- if (auto elementCount = type->ArrayLength)
- {
- auto irElementCount = lowerSimpleVal(context, elementCount);
- return getArrayType(loweredElementType, irElementCount);
- }
- else
- {
- return getArrayType(loweredElementType, nullptr);
- }
+ return LoweredTypeInfo(type);
}
};
@@ -1017,9 +1017,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
if (auto fieldDeclRef = declRef.As<StructField>())
{
// Okay, easy enough: we have a reference to a field of a struct type...
-
- auto loweredField = ensureDecl(context, fieldDeclRef);
- return extractField(loweredType, loweredBase, loweredField);
+ return extractField(loweredType, loweredBase, fieldDeclRef);
}
else if (auto callableDeclRef = declRef.As<CallableDecl>())
{
@@ -1045,14 +1043,12 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// in order for a dereference to make senese, so we just
// need to extract the value type from that pointer here.
//
- auto loweredBaseVal = getSimpleVal(context, loweredBase);
- auto loweredBaseType = loweredBaseVal->getType();
- switch( loweredBaseType->op )
- {
- case kIROp_PtrType:
- // TODO: should we enumerate these explicitly?
- case kIROp_ConstantBufferType:
- case kIROp_TextureBufferType:
+ IRValue* loweredBaseVal = getSimpleVal(context, loweredBase);
+ RefPtr<Type> loweredBaseType = loweredBaseVal->getType();
+
+ if (loweredBaseType->As<PointerLikeType>()
+ || loweredBaseType->As<PtrType>())
+ {
// Note that we do *not* perform an actual `load` operation
// here, but rather just use the pointer value to construct
// an appropriate `LoweredValInfo` representing the underlying
@@ -1064,8 +1060,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
// and is just a bit of pointer math.
//
return LoweredValInfo::ptr(loweredBaseVal);
-
- default:
+ }
+ else
+ {
SLANG_UNIMPLEMENTED_X("codegen for deref expression");
return LoweredValInfo();
}
@@ -1472,7 +1469,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
case LoweredValInfo::Flavor::Ptr:
return LoweredValInfo::ptr(
builder->emitElementAddress(
- builder->getPtrType(getSimpleType(type)),
+ getPtrType(context, getSimpleType(type)),
baseVal.val,
indexVal));
@@ -1484,9 +1481,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
}
LoweredValInfo extractField(
- LoweredTypeInfo fieldType,
- LoweredValInfo base,
- LoweredValInfo field)
+ LoweredTypeInfo fieldType,
+ LoweredValInfo base,
+ DeclRef<StructField> field)
{
switch (base.flavor)
{
@@ -1497,7 +1494,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
getBuilder()->emitFieldExtract(
getSimpleType(fieldType),
irBase,
- (IRStructField*) getSimpleVal(context, field)));
+ getBuilder()->getDeclRefVal(field)));
}
break;
@@ -1509,9 +1506,9 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
IRValue* irBasePtr = base.val;
return LoweredValInfo::ptr(
getBuilder()->emitFieldAddress(
- getBuilder()->getPtrType(getSimpleType(fieldType)),
+ getPtrType(context, getSimpleType(fieldType)),
irBasePtr,
- (IRStructField*) getSimpleVal(context, field)));
+ getBuilder()->getDeclRefVal(field)));
}
break;
}
@@ -1598,7 +1595,7 @@ struct RValueExprLoweringVisitor : ExprLoweringVisitorBase<RValueExprLoweringVis
auto builder = getBuilder();
- auto irIntType = builder->getBaseType(BaseType::Int);
+ auto irIntType = getIntType(context);
UInt elementCount = (UInt)expr->elementCount;
IRValue* irElementIndices[4];
@@ -1661,34 +1658,22 @@ struct StmtLoweringVisitor : StmtVisitor<StmtLoweringVisitor>
void insertBlock(IRBlock* block)
{
auto builder = getBuilder();
- auto parent = builder->parentInst;
- IRBlock* prevBlock = nullptr;
- IRFunc* parentFunc = nullptr;
-
- switch (parent->op)
- {
- case kIROp_Block:
- prevBlock = (IRBlock*)parent;
- parentFunc = prevBlock->getParent();
- break;
-
- default:
- SLANG_UNEXPECTED("bad parent kind for block");
- return;
- }
+ auto prevBlock = builder->block;
+ auto parentFunc = prevBlock->parentFunc;
// If the previous block doesn't already have
// a terminator instruction, then be sure to
// emit a branch to the new block.
- if (!isTerminatorInst(prevBlock->lastChild))
+ if (!isTerminatorInst(prevBlock->lastInst))
{
builder->emitBranch(block);
}
- builder->parentInst = parentFunc;
- builder->addInst(block);
- builder->parentInst = block;
+ parentFunc->addBlock(block);
+
+ builder->func = parentFunc;
+ builder->block = block;
}
// Start a new block at the current location.
@@ -2025,8 +2010,8 @@ top:
auto loweredBase = swizzleInfo->base;
// Load from the base value:
- IRInst* irLeftVal = getSimpleVal(context, loweredBase);
- auto irRightVal = getSimpleVal(context, right);
+ IRValue* irLeftVal = getSimpleVal(context, loweredBase);
+ IRValue* irRightVal = getSimpleVal(context, right);
// Now apply the swizzle
IRInst* irSwizzled = builder->emitSwizzleSet(
@@ -2066,7 +2051,7 @@ top:
emitCallToDeclRef(
context,
- builder->getVoidType(),
+ context->getSession()->getVoidType(),
setterDeclRef,
allArgs,
subscriptInfo->genericArgCount);
@@ -2153,8 +2138,67 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
return LoweredValInfo();
}
+ bool isGlobalVarDecl(VarDeclBase* decl)
+ {
+ auto parent = decl->ParentDecl;
+ if (dynamic_cast<ModuleDecl*>(parent))
+ {
+ // Variable declared at global scope? -> Global.
+ return true;
+ }
+
+ return false;
+ }
+
+ LoweredValInfo lowerGlobalVarDecl(VarDeclBase* decl)
+ {
+ auto varType = lowerSimpleType(context, decl->getType());
+
+ IRAddressSpace addressSpace = kIRAddressSpace_Default;
+ if (decl->HasModifier<HLSLGroupSharedModifier>())
+ {
+ addressSpace = kIRAddressSpace_GroupShared;
+ }
+
+ auto builder = getBuilder();
+ auto irGlobal = builder->createGlobalVar(varType);
+
+ if (decl)
+ {
+ builder->addHighLevelDeclDecoration(irGlobal, decl);
+ }
+
+ if (auto layout = getLayout())
+ {
+ builder->addLayoutDecoration(irGlobal, layout);
+ }
+
+ // A global variable's SSA value is a *pointer* to
+ // the underlying storage.
+ auto globalVal = LoweredValInfo::ptr(irGlobal);
+ context->shared->declValues.Add(
+ DeclRef<VarDeclBase>(decl, nullptr),
+ globalVal);
+
+ if( auto initExpr = decl->initExpr )
+ {
+ // TODO: need to handle global with initializer!
+ }
+
+ getBuilder()->getModule()->globalValues.Add(irGlobal);
+
+ return globalVal;
+ }
+
LoweredValInfo visitVarDeclBase(VarDeclBase* decl)
{
+ // Detect global (or effectively global) variables
+ // and handle them differently.
+ if (isGlobalVarDecl(decl))
+ {
+ return lowerGlobalVarDecl(decl);
+ }
+
// A user-defined variable declaration will usually turn into
// an `alloca` operation for the variable's storage,
// plus some code to initialize it and then store to the variable.
@@ -2219,6 +2263,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl)
{
+#if 0
// User-defined aggregate type: need to translate into
// a corresponding IR aggregate type.
@@ -2257,6 +2302,10 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
builder->addInst(irStruct);
return LoweredValInfo::simple(irStruct);
+#else
+ // TODO: What is there to do with a `struct` type?
+ return LoweredValInfo();
+#endif
}
// Sometimes we need to refer to a declaration the way that it would be specialized
@@ -2592,7 +2641,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
}
void trySetMangledName(
- IRInst* inst,
+ IRFunc* irFunc,
Decl* decl)
{
// We want to generate a mangled name for the given declaration and attach
@@ -2605,8 +2654,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
String mangledName = getMangledName(decl);
- auto decoration = getBuilder()->addDecoration<IRMangledNameDecoration>(inst);
- decoration->mangledName = mangledName;
+ irFunc->mangledName = mangledName;
}
@@ -2649,11 +2697,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// need to create an IR function here
IRFunc* irFunc = subBuilder->createFunc();
- subBuilder->parentInst = irFunc;
+ subBuilder->func = irFunc;
trySetMangledName(irFunc, decl);
- List<IRType*> paramTypes;
+ List<RefPtr<Type>> paramTypes;
// We first need to walk the generic parameters (if any)
// because these will influence the declared type of
@@ -2662,6 +2710,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
for( auto genericParamDecl : parameterLists.genericParams )
{
UInt genericParamIndex = genericParamCounter++;
+#if 0
if( auto genericTypeParamDecl = dynamic_cast<GenericTypeParamDecl*>(genericParamDecl) )
{
// In the logical type for the function, a generic
@@ -2675,10 +2724,11 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// to the appropriate generic parameter position.
IRType* irParameterType = context->irBuilder->getGenericParameterType(genericParamIndex);
- LoweredValInfo LoweredValInfo = LoweredValInfo::simple(irParameterType);
+ LoweredValInfo LoweredValInfo = LoweredValInfo::type(irParameterType);
subContext->shared->declValues[makeDeclRef(genericTypeParamDecl)] = LoweredValInfo;
}
else
+#endif
{
// TODO: handle the other cases here.
SLANG_UNEXPECTED("generic parameter kind");
@@ -2702,7 +2752,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = subBuilder->getPtrType(irParamType);
+ auto irPtrType = getPtrType(context, irParamType);
paramTypes.Add(irPtrType);
}
}
@@ -2726,14 +2776,15 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// Instead, a setter always returns `void`
//
- irResultType = getBuilder()->getVoidType();
+ irResultType = context->getSession()->getVoidType();
}
- auto irFuncType = getBuilder()->getFuncType(
+ auto irFuncType = getFuncType(
+ context,
paramTypes.Count(),
paramTypes.Buffer(),
irResultType);
- irFunc->type.init(irFunc, irFuncType);
+ irFunc->type = irFuncType;
if (!decl->Body)
{
@@ -2753,7 +2804,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// This is a function definition, so we need to actually
// construct IR for the body...
IRBlock* entryBlock = subBuilder->emitBlock();
- subBuilder->parentInst = entryBlock;
+ subBuilder->block = entryBlock;
UInt paramTypeIndex = 0;
for( auto paramInfo : parameterLists.params )
@@ -2771,7 +2822,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
//
// TODO: Is this the best representation we can use?
- auto irPtrType = (IRPtrType*)irParamType;
+ auto irPtrType = irParamType.As<PtrType>();
IRParam* irParamPtr = subBuilder->emitParam(irPtrType);
if(auto paramDecl = paramInfo.decl)
@@ -2829,14 +2880,21 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
// We need to carefully add a terminator instruction to the end
// of the body, in case the user didn't do so.
- if (!isTerminatorInst(subContext->irBuilder->parentInst->lastChild))
+ if (!isTerminatorInst(subContext->irBuilder->block->lastInst))
{
- if (irResultType->op == kIROp_VoidType)
+ if (irResultType->Equals(context->getSession()->getVoidType()))
{
+ // `void`-returning function can get an implicit
+ // return on exit of the body statement.
subContext->irBuilder->emitReturn();
}
else
{
+ // Value-returning function is expected to `return`
+ // on every control-flow path. We need to enforce
+ // this by putting an `unreachable` terminator here,
+ // and then emit a dataflow error if this block
+ // can't be eliminated.
SLANG_UNEXPECTED("Needed a return here");
subContext->irBuilder->emitReturn();
}
@@ -2845,7 +2903,7 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
getBuilder()->addHighLevelDeclDecoration(irFunc, decl);
- getBuilder()->addInst(irFunc);
+ getBuilder()->getModule()->globalValues.Add(irFunc);
return LoweredValInfo::simple(irFunc);
}
@@ -2878,7 +2936,6 @@ LoweredValInfo ensureDecl(
IRBuilder subIRBuilder;
subIRBuilder.shared = context->irBuilder->shared;
- subIRBuilder.parentInst = subIRBuilder.shared->module;
IRGenContext subContext = *context;
@@ -2997,15 +3054,14 @@ IRModule* lowerEntryPointToIR(
SharedIRBuilder sharedBuilderStorage;
SharedIRBuilder* sharedBuilder = &sharedBuilderStorage;
sharedBuilder->module = nullptr;
+ sharedBuilder->session = entryPoint->compileRequest->mSession;
IRBuilder builderStorage;
IRBuilder* builder = &builderStorage;
builder->shared = sharedBuilder;
- builder->parentInst = nullptr;
IRModule* module = builder->createModule();
sharedBuilder->module = module;
- builder->parentInst = module;
context->irBuilder = builder;
diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp
index 85f19f14c..6d848062c 100644
--- a/source/slang/lower.cpp
+++ b/source/slang/lower.cpp
@@ -708,6 +708,12 @@ struct LoweringVisitor
return result;
}
+ RefPtr<Type> visitIRBasicBlockType(IRBasicBlockType* type)
+ {
+ return type;
+ }
+
+
RefPtr<Type> visitErrorType(ErrorType* type)
{
return type;
@@ -732,9 +738,16 @@ struct LoweringVisitor
RefPtr<Type> visitFuncType(FuncType* type)
{
- RefPtr<FuncType> loweredType = getFuncType(
- getSession(),
- translateDeclRef(DeclRef<Decl>(type->declRef)).As<CallableDecl>());
+ RefPtr<FuncType> loweredType = new FuncType();
+ loweredType->resultType = lowerType(type->resultType);
+ for (auto paramType : type->paramTypes)
+ {
+ auto loweredParamType = lowerType(paramType);
+
+ // TODO: it seems like this step needs to scalarize
+ // in the case where a parameter type is a tuple...
+ loweredType->paramTypes.Add(loweredParamType);
+ }
return loweredType;
}
diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp
index 475802cb1..4dbb6c05b 100644
--- a/source/slang/syntax.cpp
+++ b/source/slang/syntax.cpp
@@ -216,6 +216,9 @@ void Type::accept(IValVisitor* visitor, void* extra)
overloadedType = new OverloadGroupType();
overloadedType->setSession(this);
+
+ irBasicBlockType = new IRBasicBlockType();
+ irBasicBlockType->setSession(this);
}
Type* Session::getBoolType()
@@ -268,6 +271,29 @@ void Type::accept(IValVisitor* visitor, void* extra)
return errorType;
}
+ Type* Session::getIRBasicBlockType()
+ {
+ return irBasicBlockType;
+ }
+
+ RefPtr<PtrType> Session::getPtrType(
+ RefPtr<Type> valueType)
+ {
+ auto genericDecl = findMagicDecl(
+ this, "PtrType").As<GenericDecl>();
+ auto typeDecl = genericDecl->inner;
+
+ auto substitutions = new Substitutions();
+ substitutions->genericDecl = genericDecl.Ptr();
+ substitutions->args.Add(valueType);
+
+ auto declRef = DeclRef<Decl>(typeDecl.Ptr(), substitutions);
+
+ return DeclRefType::Create(
+ this,
+ declRef)->As<PtrType>();
+ }
+
SyntaxClass<RefObject> Session::findSyntaxClass(Name* name)
{
SyntaxClass<RefObject> syntaxClass;
@@ -545,7 +571,26 @@ void Type::accept(IValVisitor* visitor, void* extra)
else
{
- SLANG_UNEXPECTED("unhandled type");
+ auto classInfo = session->findSyntaxClass(
+ session->getNamePool()->getName(magicMod->name));
+ if (!classInfo.classInfo)
+ {
+ SLANG_UNEXPECTED("unhandled type");
+ }
+
+ auto type = classInfo.createInstance();
+ if (!type)
+ {
+ SLANG_UNEXPECTED("constructor failure");
+ }
+
+ auto declRefType = dynamic_cast<DeclRefType*>(type);
+ if (!declRefType)
+ {
+ SLANG_UNEXPECTED("expected a declaration reference type");
+ }
+ declRefType->declRef = declRef;
+ return declRefType;
}
}
else
@@ -578,6 +623,28 @@ void Type::accept(IValVisitor* visitor, void* extra)
return (int)(int64_t)(void*)this;
}
+ // IRBasicBlockType
+
+ String IRBasicBlockType::ToString()
+ {
+ return "Block";
+ }
+
+ bool IRBasicBlockType::EqualsImpl(Type * /*type*/)
+ {
+ return false;
+ }
+
+ Type* IRBasicBlockType::CreateCanonicalType()
+ {
+ return this;
+ }
+
+ int IRBasicBlockType::GetHashCode()
+ {
+ return (int)(int64_t)(void*)this;
+ }
+
// InitializerListType
String InitializerListType::ToString()
@@ -653,18 +720,43 @@ void Type::accept(IValVisitor* visitor, void* extra)
String FuncType::ToString()
{
- // TODO: a better approach than this
- if (declRef)
- return getText(declRef.GetName());
- else
- return "/* unknown FuncType */";
+ StringBuilder sb;
+ sb << "(";
+ UInt paramCount = getParamCount();
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ if (pp != 0) sb << ", ";
+ sb << getParamType(pp)->ToString();
+ }
+ sb << ") -> ";
+ sb << getResultType()->ToString();
+ return sb.ProduceString();
}
bool FuncType::EqualsImpl(Type * type)
{
if (auto funcType = type->As<FuncType>())
{
- return declRef == funcType->declRef;
+ auto paramCount = getParamCount();
+ auto otherParamCount = funcType->getParamCount();
+ if (paramCount != otherParamCount)
+ return false;
+
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ auto paramType = getParamType(pp);
+ auto otherParamType = funcType->getParamType(pp);
+ if (!paramType->Equals(otherParamType))
+ return false;
+ }
+
+ if(!resultType->Equals(funcType->resultType))
+ return false;
+
+ // TODO: if we ever introduce other kinds
+ // of qualification on function types, we'd
+ // want to consider it here.
+ return true;
}
return false;
}
@@ -676,7 +768,16 @@ void Type::accept(IValVisitor* visitor, void* extra)
int FuncType::GetHashCode()
{
- return declRef.GetHashCode();
+ int hashCode = getResultType()->GetHashCode();
+ UInt paramCount = getParamCount();
+ hashCode = combineHash(hashCode, Slang::GetHashCode(paramCount));
+ for (UInt pp = 0; pp < paramCount; ++pp)
+ {
+ hashCode = combineHash(
+ hashCode,
+ getParamType(pp)->GetHashCode());
+ }
+ return hashCode;
}
// TypeType
@@ -782,6 +883,13 @@ void Type::accept(IValVisitor* visitor, void* extra)
return this->declRef.substitutions->args[2].As<IntVal>().Ptr();
}
+ // PtrTypeBase
+
+ Type* PtrTypeBase::getValueType()
+ {
+ return this->declRef.substitutions->args[0].As<Type>().Ptr();
+ }
+
// GenericParamIntVal
bool GenericParamIntVal::EqualsVal(Val* val)
@@ -1161,7 +1269,13 @@ void Type::accept(IValVisitor* visitor, void* extra)
{
auto funcType = new FuncType();
funcType->setSession(session);
- funcType->declRef = declRef;
+
+ funcType->resultType = GetResultType(declRef);
+ for (auto pp : GetParameters(declRef))
+ {
+ funcType->paramTypes.Add(GetType(pp));
+ }
+
return funcType;
}
diff --git a/source/slang/syntax.h b/source/slang/syntax.h
index 8d6a1edbc..f5e22d9b5 100644
--- a/source/slang/syntax.h
+++ b/source/slang/syntax.h
@@ -74,6 +74,29 @@ namespace Slang
kConversionCost_ScalarToVector = 1,
};
+ // TODO(tfoley): We should ditch this enumeration
+ // and just use the IR opcodes that represent these
+ // types directly. The one major complication there
+ // is that the order of the enum values currently
+ // matters, since it determines promotion rank.
+ // We either need to keep that restriction, or
+ // look up promotion rank by some other means.
+ //
+ enum class BaseType
+ {
+ // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this
+
+ Void = 0,
+ Bool,
+ Int,
+ UInt,
+ UInt64,
+ Half,
+ Float,
+ Double,
+ };
+
+
// Forward-declare all syntax classes
#define SYNTAX_CLASS(NAME, BASE, ...) class NAME;
diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h
index 7883b5c42..7fbefc368 100644
--- a/source/slang/type-defs.h
+++ b/source/slang/type-defs.h
@@ -41,6 +41,20 @@ protected:
)
END_SYNTAX_CLASS()
+// The type of a reference to a basic block
+// in our IR
+SYNTAX_CLASS(IRBasicBlockType, Type)
+RAW(
+public:
+ virtual String ToString() override;
+
+protected:
+ virtual bool EqualsImpl(Type * type) override;
+ virtual Type* CreateCanonicalType() override;
+ virtual int GetHashCode() override;
+)
+END_SYNTAX_CLASS()
+
// A type that takes the form of a reference to some declaration
SYNTAX_CLASS(DeclRefType, Type)
DECL_FIELD(DeclRef<Decl>, declRef)
@@ -221,6 +235,8 @@ END_SYNTAX_CLASS()
// Other cases of generic types known to the compiler
SYNTAX_CLASS(BuiltinGenericType, DeclRefType)
SYNTAX_FIELD(RefPtr<Type>, elementType)
+
+ RAW(Type* getElementType() { return elementType; })
END_SYNTAX_CLASS()
// Types that behave like pointers, in that they can be
@@ -353,6 +369,33 @@ protected:
)
END_SYNTAX_CLASS()
+// Base class for types that map down to
+// simple pointers as part of code generation.
+SYNTAX_CLASS(PtrTypeBase, DeclRefType)
+RAW(
+ // Get the type of the pointed-to value.
+ Type* getValueType();
+)
+END_SYNTAX_CLASS()
+
+// A true (user-visible) pointer type, e.g., `T*`
+SYNTAX_CLASS(PtrType, PtrTypeBase)
+END_SYNTAX_CLASS()
+
+// A type that represents the behind-the-scenes
+// logical pointer that is passed for an `out`
+// or `in out` parameter
+SYNTAX_CLASS(OutTypeBase, PtrTypeBase)
+END_SYNTAX_CLASS()
+
+// The type for an `out` parameter, e.g., `out T`
+SYNTAX_CLASS(OutType, OutTypeBase)
+END_SYNTAX_CLASS()
+
+// The type for an `in out` parameter, e.g., `in out T`
+SYNTAX_CLASS(InOutType, OutTypeBase)
+END_SYNTAX_CLASS()
+
// A type alias of some kind (e.g., via `typedef`)
SYNTAX_CLASS(NamedExpressionType, Type)
DECL_FIELD(DeclRef<TypeDefDecl>, declRef)
@@ -375,17 +418,27 @@ protected:
)
END_SYNTAX_CLASS()
-// Function types are currently used for references to symbols that name
-// either ordinary functions, or "component functions."
-// We do not directly store a representation of the type, and instead
-// use a reference to the symbol to stand in for its logical type
+// A function type is defined by its parameter types
+// and its result type.
SYNTAX_CLASS(FuncType, Type)
- DECL_FIELD(DeclRef<CallableDecl>, declRef)
+
+ // TODO: We may want to preserve parameter names
+ // in the list here, just so that we can print
+ // out friendly names when printing a function
+ // type, even if they don't affect the actual
+ // semantic type underneath.
+
+ FIELD(List<RefPtr<Type>>, paramTypes)
+ FIELD(RefPtr<Type>, resultType)
RAW(
FuncType()
{}
+ UInt getParamCount() { return paramTypes.Count(); }
+ Type* getParamType(UInt index) { return paramTypes[index]; }
+ Type* getResultType() { return resultType; }
+
virtual String ToString() override;
protected:
virtual bool EqualsImpl(Type * type) override;
diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp
index 65ecefc01..d3e44a947 100644
--- a/source/slang/vm.cpp
+++ b/source/slang/vm.cpp
@@ -11,28 +11,43 @@
namespace Slang
{
-// Representation of a type during VM execution
-struct VMTypeImpl
+struct VMValImpl
{
+ // opcode used to construct the value
uint32_t op;
+};
+
+// Representation of a type during VM execution
+struct VMTypeImpl : VMValImpl
+{
+ // number of arguments to the type
+ uint32_t argCount;
// Size and alignment of instances of this
// type.
UInt size;
UInt alignment;
+
+ // operands follow
};
-struct VMType
+
+struct VMVal
{
- VMTypeImpl* impl;
+ VMValImpl* impl;
- UInt getSize() { return impl->size; }
- UInt getAlignment() { return impl->alignment; }
+ VMValImpl* getImpl() { return impl; }
+};
+
+struct VMType : VMVal
+{
+ VMTypeImpl* getImpl() { return (VMTypeImpl*) impl; }
+ UInt getSize() { return getImpl()->size; }
+ UInt getAlignment() { return getImpl()->alignment; }
};
struct VMPtrTypeImpl : VMTypeImpl
{
VMType base;
- uint32_t addressSpace;
};
struct VMReg
@@ -47,20 +62,20 @@ struct VMReg
struct VMConst
{
// Type of the constant
- VMType type;
+ VMType type;
// Operand address to use
void* ptr;
};
-struct VMFrame;
+struct VMModule;
// Information about a function after it has been
// loaded into the VM.
struct VMFunc
{
// The parent module that this function belongs to
- VMFrame* module;
+ VMModule* module;
BCFunc* bcFunc;
VMReg* regs;
@@ -84,8 +99,14 @@ struct VMFrame
// Registers are stored after this point.
};
-struct VMModule : VMFunc
+struct VM;
+
+struct VMModule
{
+ BCModule* bcModule;
+ VM* vm;
+ void** symbols;
+ VMType* types;
};
UInt decodeUInt(BCOp** ioPtr)
@@ -93,13 +114,28 @@ UInt decodeUInt(BCOp** ioPtr)
BCOp* ptr = *ioPtr;
UInt value = *ptr++;
- if( value >= 128 )
+ if( value < 128 )
{
- SLANG_UNEXPECTED("deal with this later");
+ *ioPtr = ptr;
+ return value;
}
- *ioPtr = ptr;
- return value;
+ // Slower path for variable-length encoding
+
+ UInt result = 0;
+ for(;;)
+ {
+ value = value & 0x7F;
+ result = (result << 7) | value;
+
+ if(value < 127)
+ {
+ *ioPtr = ptr;
+ return value;
+ }
+
+ value = *ptr++;
+ }
}
Int decodeSInt(BCOp** ioPtr)
@@ -194,9 +230,15 @@ T& decodeOperand(VMFrame* frame, BCOp** ioIP)
return *decodeOperandPtr<T>(frame, ioIP);
}
+VMType decodeType(VMFrame* frame, BCOp** ioIP)
+{
+ UInt id = decodeUInt(ioIP);
+ return frame->func->module->types[id];
+}
+
VMFunc* loadVMFunc(
BCFunc* bcFunc,
- VMFrame* vmModuleInstance);
+ VMModule* vmModule);
struct VMSizeAlign
{
@@ -231,9 +273,30 @@ VMSizeAlign getVMSymbolSize(BCSymbol* symbol)
return result;
}
+VMType getType(
+ VMModule* vmModule,
+ uint32_t typeID)
+{
+ return vmModule->types[typeID];
+}
+
+void* getGlobalPtr(
+ VMModule* vmModule,
+ uint32_t globalID)
+{
+ return vmModule->symbols[globalID];
+}
+
+VMType getGlobalType(
+ VMModule* vmModule,
+ uint32_t globalID)
+{
+ return getType(vmModule, vmModule->bcModule->symbols[globalID]->typeID);
+}
+
VMFunc* loadVMFunc(
BCFunc* bcFunc,
- VMFrame* vmModuleInstance)
+ VMModule* vmModule)
{
UInt regCount = bcFunc->regCount;
UInt constCount = bcFunc->constCount;
@@ -245,7 +308,7 @@ VMFunc* loadVMFunc(
VMReg* vmRegs = (VMReg*) (vmFunc + 1);
VMConst* vmConsts = (VMConst*) (vmRegs + regCount);
- vmFunc->module = vmModuleInstance;
+ vmFunc->module = vmModule;
vmFunc->bcFunc = bcFunc;
vmFunc->regs = vmRegs;
vmFunc->consts = vmConsts;
@@ -254,31 +317,14 @@ VMFunc* loadVMFunc(
for( UInt rr = 0; rr < regCount; ++rr )
{
BCReg* bcReg = &bcFunc->regs[rr];
- auto typeGlobalID = bcReg->typeGlobalID;
-
- // HACK: when we are loading a module itself, we might
- // not yet know the size for the things it defines
- // (since the module itself might define the type of
- // one of its symbols), so for now we hack it and
- // assume everything at module level is 16 bytes or less.
- //
- // TODO: this also seems like it will cause problems
- // in other contexts (any time the type of a register
- // would depend on an earlier instruction in the same
- // scope) so this needs careful thought.
- VMType vmType = { nullptr };
- UInt regSize = 16;
- UInt regAlign = 8;
-
- if (vmModuleInstance)
- {
- // We expect the type to come from the outer module, so
- // that we can allocate space for it as we go.
- vmType = *(VMType*)getRegPtrImpl(vmModuleInstance, typeGlobalID);
+ auto bcTypeID = bcReg->typeID;
- regSize = vmType.getSize();
- regAlign = vmType.getAlignment();
- }
+ // We expect the type to come from the outer module, so
+ // that we can allocate space for it as we go.
+ auto vmType = getType(vmModule, bcTypeID);
+
+ auto regSize = vmType.getSize();
+ auto regAlign = vmType.getAlignment();
offset = (offset + (regAlign-1)) & ~(regAlign-1);
@@ -292,10 +338,32 @@ VMFunc* loadVMFunc(
for( UInt cc = 0; cc < constCount; ++cc )
{
- auto globalID = bcFunc->consts[cc].globalID;
- auto globalPtr = getRegPtrImpl(vmModuleInstance, globalID);
- vmFunc->consts[cc].ptr = globalPtr;
- vmFunc->consts[cc].type = vmModuleInstance->func->regs[globalID].type;
+ BCConst bcConst = bcFunc->consts[cc];
+ switch( bcConst.flavor )
+ {
+ case kBCConstFlavor_GlobalSymbol:
+ {
+ auto globalID = bcConst.id;
+ vmFunc->consts[cc].ptr = &vmModule->symbols[globalID];
+ vmFunc->consts[cc].type = getGlobalType(vmModule, globalID);
+ }
+ break;
+
+ case kBCConstFlavor_Constant:
+ {
+ auto constID = bcConst.id;
+ auto constInfo = &vmModule->bcModule->constants[constID];
+ vmFunc->consts[cc].ptr = constInfo->ptr;
+ #if 0
+ fprintf(stderr, "CONSANT[%d] : [%p]\n", (int)cc, vmFunc->consts[cc].ptr);
+ fprintf(stderr, "BC [%p] : %d\n", &constInfo->ptr, (int)constInfo->ptr.rawVal);
+ #endif
+ vmFunc->consts[cc].type = getType(vmModule, constInfo->typeID);
+ }
+ break;
+ }
+
+
}
return vmFunc;
@@ -368,7 +436,7 @@ void dumpVMFrame(VMFrame* vmFrame)
case kIROp_PtrType:
{
- fprintf(stderr, ": Ptr<???> = 0x%p", *(void**)regData);
+ fprintf(stderr, ": Ptr<?> = [%p]", *(void**)regData);
}
break;
@@ -416,7 +484,186 @@ struct VMThread
void resumeThread(
VMThread* vmThread);
-VMFrame* loadVMModuleInstance(
+void computeTypeSizeAlign(
+ VMTypeImpl* impl)
+{
+ UInt size = 0;
+ UInt alignment = 0;
+ switch(impl->op)
+ {
+ case kIROp_VoidType:
+ size = 0;
+ break;
+
+ case kIROp_BoolType:
+ size = 1;
+ break;
+
+ case kIROp_Int32Type:
+ case kIROp_UInt32Type:
+ case kIROp_Float32Type:
+ size = 4;
+ break;
+
+ case kIROp_FuncType:
+ case kIROp_PtrType:
+ case kIROp_readWriteStructuredBufferType:
+ case kIROp_structuredBufferType:
+ size = sizeof(void*);
+ break;
+
+ default:
+ SLANG_UNIMPLEMENTED_X("type sizing");
+ impl->size = 0;
+ break;
+ }
+
+ if(!alignment)
+ alignment = size;
+ if(!alignment)
+ alignment = 1;
+
+ impl->size = size;
+ impl->alignment = alignment;
+}
+
+VMType getType(
+ VM* vm,
+ VMTypeImpl* typeImpl)
+{
+ // TODO: need to look up an existing type that matches...
+
+ UInt argCount = typeImpl->argCount;
+ UInt size = sizeof(VMTypeImpl) + argCount*sizeof(VMType);
+
+ VMTypeImpl* impl = (VMTypeImpl*) malloc(size);
+ memcpy(impl, typeImpl, size);
+
+ computeTypeSizeAlign(impl);
+
+ VMType type;
+ type.impl = impl;
+ return type;
+}
+
+VMVal getVal(
+ VMModule* vmModule,
+ UInt index)
+{
+ return vmModule->types[index];
+}
+
+VMType loadVMType(
+ VMModule* vmModule,
+ BCType* bcType)
+{
+ // Need to load type from BC format to VM
+ IROp op = (IROp) bcType->op;
+ switch(bcType->op)
+ {
+ case kIROp_PtrType:
+ {
+ // TODO: need to do some caching!
+ BCPtrType* bcPtrType = (BCPtrType*) bcType;
+
+ VMPtrTypeImpl vmPtrTypeImpl;
+ vmPtrTypeImpl.op = op;
+ vmPtrTypeImpl.argCount = 1;
+ vmPtrTypeImpl.size = sizeof(void*);
+ vmPtrTypeImpl.alignment = sizeof(void*);
+ vmPtrTypeImpl.base = getType(vmModule, bcPtrType->valueType->id);
+
+ auto vmPtrType = getType(vmModule->vm, &vmPtrTypeImpl);
+ return vmPtrType;
+ }
+ break;
+
+ default:
+ {
+ UInt argCount = bcType->argCount;
+
+ UInt size = sizeof(VMTypeImpl) + argCount * sizeof(VMVal);
+
+ VMTypeImpl* impl = (VMTypeImpl*) alloca(size);
+ memset(impl, 0, size);
+ impl->op = bcType->op;
+ impl->argCount = argCount;
+
+ VMVal* args = (VMVal*) (impl + 1);
+ for(UInt aa = 0; aa < argCount; ++aa)
+ {
+ args[aa] = getVal(vmModule, bcType->getArg(aa)->id);
+ }
+
+ return getType(vmModule->vm, impl);
+ }
+
+ SLANG_UNEXPECTED("unimplemented");
+ return VMType();
+ break;
+ }
+}
+
+void* allocateImpl(VM* vm, UInt size, UInt align)
+{
+ void* ptr = malloc(size);
+ memset(ptr, 0, size);
+ return ptr;
+}
+
+void* allocate(VM* vm, VMType type)
+{
+ return allocateImpl(vm, type.getSize(), type.getAlignment());
+}
+
+template<typename T>
+T* allocate(VM* vm)
+{
+ return allocateImpl(vm, sizeof(T), alignof(T));
+}
+
+void* loadVMSymbol(
+ VMModule* vmModule,
+ BCSymbol* bcSymbol)
+{
+ // Need to load type from BC format to VM
+
+ auto vm = vmModule->vm;
+
+ switch(bcSymbol->op)
+ {
+ case kIROp_global_var:
+ {
+ auto type = getType(vmModule, bcSymbol->typeID);
+ assert(type.impl->op == kIROp_PtrType);
+
+ VMPtrTypeImpl* ptrTypeImpl = (VMPtrTypeImpl*) type.impl;
+ auto valueType = ptrTypeImpl->base;
+
+ void* varValue = allocate(vm, valueType);
+ void** varPtr = (void**) allocate(vm, type);
+
+
+ *varPtr = varValue;
+
+ return varPtr;
+ }
+ break;
+
+ case kIROp_Func:
+ {
+ auto bcFunc = (BCFunc*) bcSymbol;
+ VMFunc* vmFunc = loadVMFunc(bcFunc, vmModule);
+ return vmFunc;
+ }
+ break;
+
+ default:
+ return nullptr;
+ }
+}
+
+VMModule* loadVMModuleInstance(
VM* vm,
void const* bytecode,
size_t bytecodeSize)
@@ -425,45 +672,63 @@ VMFrame* loadVMModuleInstance(
BCModule* bcModule = bcHeader->module;
- UInt vmModuleSize = sizeof(VMModule) + bcModule->regCount * sizeof(VMReg);
+ UInt symbolCount = bcModule->symbolCount;
+ UInt typeCount = bcModule->typeCount;
- VMModule* vmModule = (VMModule*) loadVMFunc(bcModule, nullptr);
+ UInt vmModuleSize = sizeof(VMModule)
+ + symbolCount * sizeof(void*)
+ + typeCount * sizeof(VMType);
- // Create a frame to store the loaded symbols, and execute it
- // to initialize them.
- VMFrame* vmModuleInstance = createFrame(vmModule);
- vmModuleInstance->parent = nullptr;
- vmModule->module = vmModuleInstance;
+ VMModule* vmModule = (VMModule*)malloc(vmModuleSize);
+ memset(vmModule, 0, vmModuleSize);
+ void** vmSymbols = (void**)(vmModule + 1);
+ VMType* vmTypes = (VMType*)(vmSymbols + symbolCount);
- VMThread thread;
- thread.frame = vmModuleInstance;
+ vmModule->bcModule = bcModule;
+ vmModule->vm = vm;
+ vmModule->symbols = vmSymbols;
+ vmModule->types = vmTypes;
- resumeThread(&thread);
+ // Initialize types before symbols, since the symbols
+ // will all have types...
+ for(UInt tt = 0; tt < typeCount; ++tt)
+ {
+ BCType* bcType = bcModule->types[tt];
+ vmTypes[tt] = loadVMType(vmModule, bcType);
+ }
+
+ // Now we need to initialize all the VM-level symbols
+ // from their BC-level equivalents.
+ for(UInt ss = 0; ss < symbolCount; ++ss)
+ {
+ BCSymbol* bcSymbol = bcModule->symbols[ss];
+ vmSymbols[ss] = loadVMSymbol(vmModule, bcSymbol);
+ }
- return vmModuleInstance;
+ return vmModule;
}
void* findGlobalSymbolPtr(
- VMFrame* moduleInstance,
+ VMModule* module,
char const* name)
{
// Okay, we need to search through the available
- // "registers" looking for one that gives us a
- // name match.
- //
- BCFunc* bcFunc = moduleInstance->func->bcFunc;
- UInt regCount = bcFunc->regCount;
- for (UInt rr = 0; rr < regCount; ++rr)
+ // symbols, looking for one that gives us a name
+ // match.
+
+ BCModule* bcModule = module->bcModule;
+ UInt symbolCount = bcModule->symbolCount;
+ for(UInt ss = 0; ss < symbolCount; ++ss)
{
- BCReg* bcReg = &bcFunc->regs[rr];
+ BCSymbol* bcSymbol = bcModule->symbols[ss];
- char const* symbolName = bcReg->name;
+ char const* symbolName = bcSymbol->name;
if (!symbolName)
continue;
if(strcmp(symbolName, name) == 0)
- return getRegPtrImpl(moduleInstance, rr);
+ return getGlobalPtr(module, ss);
}
return nullptr;
@@ -516,182 +781,11 @@ void resumeThread(
auto op = (IROp) decodeUInt(&ip);
switch( op )
{
- case kIROp_TypeType:
- {
- // The type of types
- Int argCount = decodeUInt(&ip);
- void* arg0Ptr = decodeOperandPtr<void>(frame, &ip);
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
- auto typeImpl = new VMTypeImpl();
- typeImpl->op = op;
- typeImpl->size = sizeof(VMType);
- typeImpl->alignment = alignof(VMType);
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
- case kIROp_BlockType:
- case kIROp_Int32Type:
- case kIROp_UInt32Type:
- case kIROp_Float32Type:
- case kIROp_BoolType:
- case kIROp_VoidType:
- case kIROp_FuncType: // TODO: we should in principle handle function types here
- {
- // Case to handle types without arguments.
- UInt argCount = decodeUInt(&ip);
- for( UInt aa = 0; aa < argCount; ++aa )
- {
- void* argPtr = decodeOperandPtr<void>(frame, &ip);
- }
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
- auto typeImpl = new VMTypeImpl();
- typeImpl->op = op;
-
- UInt size = 1;
- UInt align = 0;
- switch (op)
- {
- case kIROp_BlockType: size = sizeof(void*); break;
- case kIROp_Int32Type: size = sizeof(int32_t); break;
- case kIROp_UInt32Type: size = sizeof(uint32_t); break;
- case kIROp_Float32Type: size = sizeof(float); break;
- case kIROp_BoolType: size = sizeof(bool); break;
- case kIROp_VoidType: size = 0; align = 1; break;
- default:
- break;
- }
- if (!align) align = size;
- typeImpl->size = size;
- typeImpl->alignment = align;
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
- case kIROp_PtrType:
- {
- // Case to handle types without arguments.
- UInt argCount = decodeUInt(&ip);
- VMType* typeTypePtr = decodeOperandPtr<VMType>(frame, &ip);
- VMType baseType = decodeOperand<VMType>(frame, &ip);
- int32_t addressSpace = decodeOperand<int32_t>(frame, &ip);
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
-
-
- auto typeImpl = new VMPtrTypeImpl();
- typeImpl->op = op;
- typeImpl->size = sizeof(void*);
- typeImpl->alignment = alignof(void*);
- typeImpl->base = baseType;
- typeImpl->addressSpace = addressSpace;
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
- case kIROp_readWriteStructuredBufferType:
- case kIROp_structuredBufferType:
- {
- // Case to handle types without arguments.
- UInt argCount = decodeUInt(&ip);
- VMType* typeTypePtr = decodeOperandPtr<VMType>(frame, &ip);
- VMType baseType = decodeOperand<VMType>(frame, &ip);
- VMType* destPtr = decodeOperandPtr<VMType>(frame, &ip);
-
-
- // TODO: give these their own representations!
- auto typeImpl = new VMPtrTypeImpl();
- typeImpl->op = op;
- typeImpl->base = baseType;
- typeImpl->size = sizeof(void*);
- typeImpl->alignment = sizeof(void*);
-
- VMType type = { typeImpl };
- *destPtr = type;
- }
- break;
-
-
- case kIROp_IntLit:
- {
- VMType type = decodeOperand<VMType>(frame, &ip);
- UInt uVal = decodeUInt(&ip);
- void* destPtr = decodeOperandPtr<void>(frame, &ip);
-
- switch( type.impl->op )
- {
- case kIROp_Int32Type:
- *(int32_t*)destPtr = int32_t(uVal);
- break;
-
- case kIROp_UInt32Type:
- *(uint32_t*)destPtr = uint32_t(uVal);
- break;
-
- default:
- SLANG_UNEXPECTED("integer type");
- break;
- }
-
- }
- break;
-
- case kIROp_FloatLit:
- {
- VMType type = decodeOperand<VMType>(frame, &ip);
-
- static const UInt size = sizeof(IRFloatingPointValue);
- IRFloatingPointValue value;
- memcpy(&value, ip, size);
- ip += size;
- void* destPtr = decodeOperandPtr<void>(frame, &ip);
-
- switch( type.impl->op )
- {
- case kIROp_Float32Type:
- *(float*)destPtr = float(value);
- break;
-
- default:
- SLANG_UNEXPECTED("float type");
- break;
- }
- }
- break;
-
- case kIROp_boolConst:
- {
- bool val = (*ip++) != 0;
- bool* destPtr = decodeOperandPtr<bool>(frame, &ip);
- *destPtr = val;
- }
- break;
-
- case kIROp_Func:
- {
- UInt nestedID = decodeUInt(&ip);
- void* destPtr = decodeOperandPtr<void>(frame, &ip);
-
- BCSymbol* bcSymbol = frame->func->bcFunc->nestedSymbols[nestedID];
- BCFunc* bcFunc = (BCFunc*)bcSymbol;
- VMFunc* vmFunc = loadVMFunc(bcFunc, frame->func->module);
-
- *(VMFunc**)destPtr = vmFunc;
- }
- break;
-
case kIROp_Var:
{
// This instruction represents the `alloca` for a variable of some type.
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
for( UInt aa = 0; aa < argCount; ++aa )
@@ -714,10 +808,17 @@ void resumeThread(
case kIROp_Store:
{
// An ordinary memory store
- VMType type = decodeOperand<VMType>(frame, &ip);
+ VMType type = decodeType(frame, &ip);
void* dest = decodeOperand<void*>(frame, &ip);
void* src = decodeOperandPtr<void>(frame, &ip);
+#if 0
+ fprintf(stderr, "STORE *[%p] = [%p] // size: %d\n",
+ dest,
+ src,
+ (int) type.getSize());
+#endif
+
memcpy(dest, src, type.getSize());
}
break;
@@ -725,7 +826,7 @@ void resumeThread(
case kIROp_Load:
{
// An ordinary memory store
- VMType type = decodeOperand<VMType>(frame, &ip);
+ VMType type = decodeType(frame, &ip);
void* src = decodeOperand<void*>(frame, &ip);
void* dest = decodeOperandPtr<void>(frame, &ip);
@@ -735,6 +836,7 @@ void resumeThread(
case kIROp_BufferLoad:
{
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
for( UInt aa = 0; aa < argCount; ++aa )
@@ -745,9 +847,8 @@ void resumeThread(
void* dest = decodeOperandPtr<void>(frame, &ip);
- VMType type = *(VMType*)argPtrs[0];
- char* bufferData = *(char**)argPtrs[1];
- uint32_t index = *(uint32_t*)argPtrs[2];
+ char* bufferData = *(char**)argPtrs[0];
+ uint32_t index = *(uint32_t*)argPtrs[1];
auto size = type.getSize();
char* elementData = bufferData + index*size;
@@ -757,9 +858,9 @@ void resumeThread(
case kIROp_BufferStore:
{
+ VMType resultType = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- VMType* resultTypePtr = decodeOperandPtr<VMType>(frame, &ip);
char* bufferData = decodeOperand<char*>(frame, &ip);
uint32_t index = decodeOperand<uint32_t>(frame, &ip);
@@ -774,8 +875,10 @@ void resumeThread(
break;
case kIROp_Call:
{
+ VMType type = decodeType(frame, &ip);
UInt operandCount = decodeUInt(&ip);
- VMType type = decodeOperand<VMType>(frame, &ip);
+
+ // First operand is the callee function
VMFunc* func = decodeOperand<VMFunc*>(frame, &ip);
// Okay, we need to create a frame to prepare the call
@@ -784,7 +887,7 @@ void resumeThread(
// Remaining arguments should populate the
// first N registers of the callee
- UInt argCount = operandCount - 2;
+ UInt argCount = operandCount - 1;
for( UInt aa = 0; aa < argCount; ++aa )
{
void* argPtr = decodeOperandPtr<void>(frame, &ip);
@@ -836,8 +939,8 @@ void resumeThread(
case kIROp_ReturnVal:
{
+ VMType instType = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- void* typePtr = decodeOperandPtr<void>(frame, &ip);
void* argPtr = decodeOperandPtr<void>(frame, &ip);
VMFrame* oldFrame = frame;
@@ -868,8 +971,8 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- void* typePtr = decodeOperandPtr<void>(frame, &ip);
Int destinationBlock = decodeSInt(&ip);
for( UInt aa = 2; aa < argCount; ++aa )
{
@@ -892,8 +995,8 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- void* typePtr = decodeOperandPtr<void>(frame, &ip);
bool* condition = decodeOperandPtr<bool>(frame, &ip);
Int trueBlockID = decodeSInt(&ip);
Int falseBlockID = decodeSInt(&ip);
@@ -917,9 +1020,9 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType resultType = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
- VMType resultType = decodeOperand<VMType>(frame, &ip);
auto leftOpnd = decodeOperandPtrAndType(frame, &ip);
auto type = leftOpnd.type;
auto leftPtr = leftOpnd.ptr;
@@ -942,8 +1045,8 @@ void resumeThread(
case kIROp_Mul:
{
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- VMType type = decodeOperand<VMType>(frame, &ip);
void* leftPtr = decodeOperandPtr<void>(frame, &ip);
void* rightPtr = decodeOperandPtr<void>(frame, &ip);
@@ -964,8 +1067,8 @@ void resumeThread(
case kIROp_Sub:
{
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
- VMType type = decodeOperand<VMType>(frame, &ip);
void* leftPtr = decodeOperandPtr<void>(frame, &ip);
void* rightPtr = decodeOperandPtr<void>(frame, &ip);
@@ -989,6 +1092,7 @@ void resumeThread(
// For now our encoding is very regular, so we can decode without
// knowing too much about an instruction...
+ VMType type = decodeType(frame, &ip);
UInt argCount = decodeUInt(&ip);
void* argPtrs[16] = { 0 };
for( UInt aa = 0; aa < argCount; ++aa )
@@ -1039,7 +1143,7 @@ SLANG_API void* SlangVMModule_findGlobalSymbolPtr(
char const* name)
{
return (SlangVMFunc*) Slang::findGlobalSymbolPtr(
- (Slang::VMFrame*) module,
+ (Slang::VMModule*) module,
name);
}