From 54f016e7ef36b7505bf47d188cf4b7e1fdc443a4 Mon Sep 17 00:00:00 2001 From: Tim Foley Date: Wed, 4 Oct 2017 13:54:25 -0700 Subject: IR: overhaul IR design/implementation (#195) * IR: overhaul IR design/implementation Closes #192 Closes #188 This is a major overhaul of how the IR is implemented, with the primary goal of just using the AST-level type representation as the IR's type representation, rather than inventing an entire shadow set of types (as captured in issue #192). One consequence of this choice is that types in the IR are no longer explicit "instructions" and are not represented as ordinary operands (so a bunch of `+ 1` cases end up going away when enumerating ordinary operands). Along the way I also got rid of the embedded IDs in the IR (issue #188) because this wasn't too hard to deal with at the same time. Another related change was to split the `IRValue` and `IRInst` cases, so that there are values that are not also instructions. Non-instruction values are now used to represent literals, references to declarations, and would eventually be used for an `undef` value if we need one. IR functions, global variables, and basic blocks are all values (because they can appear as operands), but not instructions. The main benefit of this approach is that the top-level structure of a bytecode (BC) module is much simpler to understand and walk, and BC-level types are represented much more directly (such that we could conceivably use them for reflection soon). * fixup: 64-bit build fix * fixup: try to silence clang's pedantic dependent-type errors * fixup: bug in VM loading of constants --- source/slang/bytecode.cpp | 555 ++++++++++++++------- source/slang/bytecode.h | 102 +++- source/slang/check.cpp | 27 +- source/slang/compiler.h | 8 + source/slang/core.meta.slang | 11 + source/slang/core.meta.slang.h | 12 + source/slang/emit.cpp | 487 ++++++++++++------ source/slang/ir-inst-defs.h | 29 +- source/slang/ir-insts.h | 236 +++------ source/slang/ir.cpp | 1065 ++++++++++++++++++---------------------- source/slang/ir.h | 261 +++++----- source/slang/lower-to-ir.cpp | 320 +++++++----- source/slang/lower.cpp | 19 +- source/slang/syntax.cpp | 132 ++++- source/slang/syntax.h | 23 + source/slang/type-defs.h | 63 ++- source/slang/vm.cpp | 616 +++++++++++++---------- tests/ir/loop.slang.expected | 129 ++--- tools/eval-test/main.cpp | 6 +- 19 files changed, 2368 insertions(+), 1733 deletions(-) diff --git a/source/slang/bytecode.cpp b/source/slang/bytecode.cpp index e412a5b94..854d9bf34 100644 --- a/source/slang/bytecode.cpp +++ b/source/slang/bytecode.cpp @@ -19,8 +19,8 @@ struct SharedBytecodeGenerationContext; template struct BytecodeGenerationPtr { - SharedBytecodeGenerationContext* sharedContext; UInt offset; + SharedBytecodeGenerationContext* sharedContext; BytecodeGenerationPtr() : sharedContext(nullptr) @@ -49,31 +49,40 @@ struct BytecodeGenerationPtr , offset(ptr.offset) {} - operator BCPtr() + template + BytecodeGenerationPtr bitCast() const + { + return BytecodeGenerationPtr(sharedContext, offset); + } + + operator BCPtr() const { return BCPtr(getPtr()); } - T* operator->() + T* operator->() const { return getPtr(); } - T& operator*() + T& operator*() const { return *getPtr(); } - T& operator[](UInt index) + T& operator[](UInt index) const { return getPtr()[index]; } - BytecodeGenerationPtr operator+(Int index) + BytecodeGenerationPtr operator+(Int index) const { + UInt size = sizeof(T); + Int delta = index * sizeof(T); + UInt newOffset = offset + delta; return BytecodeGenerationPtr( sharedContext, - offset + index*sizeof(T)); + newOffset); } T* getPtr() const; @@ -93,14 +102,27 @@ struct SharedBytecodeGenerationContext // The final generated bytecode stream List bytecode; - // Map from a global symbol to its global ID - Dictionary mapGlobalSymbolToGLobalID; + // Map from an IR value to a global entity + // that encodes it: + Dictionary mapValueToGlobal; + + // Types that have been emitted + List> bcTypes; + Dictionary mapTypeToID; + + // Compile-time constant values that need + // to be emitted... + List constants; }; struct BytecodeGenerationContext { SharedBytecodeGenerationContext* shared; + // The bytecode of the current symbol being + // output. + List currentBytecode; + // The function that is in scope for this context IRFunc* currentIRFunc; @@ -110,11 +132,7 @@ struct BytecodeGenerationContext // Map an instruction to its ID for use local // to the current context - Dictionary mapInstToLocalID; - - // Map an instruction to the ID for its auxiliary - // symbol data - Dictionary mapInstToNestedID; + Dictionary mapInstToLocalID; }; template @@ -169,7 +187,7 @@ void encodeUInt8( BytecodeGenerationContext* context, uint8_t value) { - context->shared->bytecode.Add(value); + context->currentBytecode.Add(value); } void encodeUInt( @@ -220,50 +238,98 @@ void encodeSInt( encodeUInt(context, uValue); } -Int getLocalID( +BCConst getGlobalValue( BytecodeGenerationContext* context, - IRInst* inst) + IRValue* value) { - Int localID = 0; - if( context->mapInstToLocalID.TryGetValue(inst, localID) ) - { - return localID; - } + BCConst bcConst; + if( context->shared->mapValueToGlobal.TryGetValue(value, bcConst) ) + return bcConst; + + // Next we need to check for things that can be mapped to + // global IDs on the fly. - Int globalID = 0; - if( context->shared->mapGlobalSymbolToGLobalID.TryGetValue(inst, globalID) ) + switch( value->op ) { - BCConst bcConst; - bcConst.globalID = globalID; + case kIROp_IntLit: + { + UInt constID = context->shared->constants.Count(); + context->shared->constants.Add(value); - UInt remappedSymbolIndex = context->remappedGlobalSymbols.Count(); - context->remappedGlobalSymbols.Add(bcConst); + BCConst bcConst; + bcConst.flavor = kBCConstFlavor_Constant; + bcConst.id = constID; - localID = ~remappedSymbolIndex; - context->mapInstToLocalID.Add(inst, localID); - return localID; + context->shared->mapValueToGlobal.Add(value, bcConst); + + return bcConst; + } + break; + + default: + break; } SLANG_UNEXPECTED("no ID for inst"); - return -9999; + bcConst.flavor = (BCConstFlavor) -1; + bcConst.id = -9999; + return bcConst; +} + +Int getLocalID( + BytecodeGenerationContext* context, + IRValue* value) +{ + Int localID = 0; + if( context->mapInstToLocalID.TryGetValue(value, localID) ) + { + return localID; + } + + BCConst bcConst = getGlobalValue(context, value); + UInt remappedSymbolIndex = context->remappedGlobalSymbols.Count(); + context->remappedGlobalSymbols.Add(bcConst); + + localID = ~remappedSymbolIndex; + context->mapInstToLocalID.Add(value, localID); + return localID; } void encodeOperand( BytecodeGenerationContext* context, - IRInst* operand) + IRValue* operand) { auto id = getLocalID(context, operand); encodeSInt(context, id); } -bool opHasResult(IRInst* inst) +uint32_t getTypeID( + BytecodeGenerationContext* context, + Type* type); + +void encodeOperand( + BytecodeGenerationContext* context, + IRType* type) +{ + encodeUInt(context, getTypeID(context, type)); +} + +bool opHasResult(IRValue* inst) { auto type = inst->getType(); - if( !type || type->op != kIROp_VoidType ) + if (!type) return false; + + // As a bit of a hack right now, we need to check whether + // the function returns the distinguished `Void` type, + // since that is conceptually the same as "not returning + // a value." + if (auto basicType = dynamic_cast(type)) { - return true; + if (basicType->baseType == BaseType::Void) + return false; } - return false; + + return true; } void generateBytecodeForInst( @@ -282,15 +348,16 @@ void generateBytecodeForInst( // auto argCount = inst->getArgCount(); + auto type = inst->getType(); encodeUInt(context, inst->op); + encodeOperand(context, inst->getType()); encodeUInt(context, argCount); for( UInt aa = 0; aa < argCount; ++aa ) { encodeOperand(context, inst->getArg(aa)); } - auto type = inst->getType(); - if( type && type->op == kIROp_VoidType ) + if (!opHasResult(inst)) { // This instructions has no type, so don't emit a destination } @@ -354,6 +421,7 @@ void generateBytecodeForInst( } break; +#if 0 case kIROp_Func: { encodeUInt(context, inst->op); @@ -368,6 +436,7 @@ void generateBytecodeForInst( encodeOperand(context, inst); } break; +#endif case kIROp_Store: { @@ -375,9 +444,9 @@ void generateBytecodeForInst( // We need to encode the type being stored, to make // our lives easier. - encodeOperand(context, inst->getArg(2)->getType()); + encodeOperand(context, inst->getArg(1)->getType()); + encodeOperand(context, inst->getArg(0)); encodeOperand(context, inst->getArg(1)); - encodeOperand(context, inst->getArg(2)); } break; @@ -385,33 +454,166 @@ void generateBytecodeForInst( { encodeUInt(context, inst->op); encodeOperand(context, inst->getType()); - encodeOperand(context, inst->getArg(1)); + encodeOperand(context, inst->getArg(0)); encodeOperand(context, inst); } break; } } -Int getIDForGlobalSymbol( +BytecodeGenerationPtr emitBCType( + BytecodeGenerationContext* context, + Type* type, + IROp op, + BytecodeGenerationPtr const* args, + UInt argCount) +{ + UInt size = sizeof(BCType) + + argCount * sizeof(BCPtr); + + BytecodeGenerationPtr bcAllocation( + context->shared, + allocateRaw(context, size, alignof(BCPtr))); + + BytecodeGenerationPtr bcType = bcAllocation.bitCast(); + auto bcArgs = (bcType + 1).bitCast>(); + + bcType->op = op; + bcType->argCount = argCount; + + for(UInt aa = 0; aa < argCount; ++aa) + { + bcArgs[aa] = args[aa]; + } + + UInt id = context->shared->bcTypes.Count(); + context->shared->mapTypeToID.Add(type, id); + context->shared->bcTypes.Add(bcType); + bcType->id = id; + + return bcType; +} + +BytecodeGenerationPtr emitBCVarArgType( + BytecodeGenerationContext* context, + Type* type, + IROp op, + List> args) +{ + return emitBCType(context, type, op, args.Buffer(), args.Count()); +} + +BytecodeGenerationPtr emitBCType( BytecodeGenerationContext* context, - IRInst* inst) + Type* type, + IROp op) +{ + return emitBCType(context, type, op, nullptr, 0); +} + +BytecodeGenerationPtr emitBCType( + BytecodeGenerationContext* context, + Type* type); + +// Emit a `BCType` representation for the given `Type` +BytecodeGenerationPtr emitBCTypeImpl( + BytecodeGenerationContext* context, + Type* type) +{ + // A NULL type is interpreted as equivalent to `Void` for now. + if( !type ) + { + return emitBCType(context, type, kIROp_VoidType); + } + + if( auto basicType = type->As() ) + { + switch(basicType->baseType) + { + case BaseType::Void: return emitBCType(context, type, kIROp_VoidType); + case BaseType::Bool: return emitBCType(context, type, kIROp_BoolType); + case BaseType::Int: return emitBCType(context, type, kIROp_Int32Type); + case BaseType::UInt: return emitBCType(context, type, kIROp_UInt32Type); + case BaseType::UInt64: return emitBCType(context, type, kIROp_UInt64Type); + case BaseType::Half: return emitBCType(context, type, kIROp_Float16Type); + case BaseType::Float: return emitBCType(context, type, kIROp_Float32Type); + case BaseType::Double: return emitBCType(context, type, kIROp_Float64Type); + + default: + break; + } + } + else if( auto funcType = type->As() ) + { + List> operands; + + operands.Add(emitBCType(context, funcType->resultType).bitCast()); + UInt paramCount = funcType->getParamCount(); + for(UInt pp = 0; pp < paramCount; ++pp) + { + operands.Add(emitBCType(context, funcType->getParamType(pp)).bitCast()); + } + + return emitBCVarArgType(context, type, kIROp_FuncType, operands); + } + else if( auto ptrType = type->As() ) + { + List> operands; + operands.Add(emitBCType(context, ptrType->getValueType()).bitCast()); + return emitBCVarArgType(context, type, kIROp_PtrType, operands); + } + else if( auto rwStructuredBufferType = type->As() ) + { + List> operands; + operands.Add(emitBCType(context, rwStructuredBufferType->elementType).bitCast()); + return emitBCVarArgType(context, type, kIROp_readWriteStructuredBufferType, operands); + } + else if( auto structuredBufferType = type->As() ) + { + List> operands; + operands.Add(emitBCType(context, structuredBufferType->elementType).bitCast()); + return emitBCVarArgType(context, type, kIROp_structuredBufferType, operands); + } + + + SLANG_UNEXPECTED("unimplemented"); + return BytecodeGenerationPtr(); +} + +BytecodeGenerationPtr emitBCType( + BytecodeGenerationContext* context, + Type* type) { - Int globalID; - if(context->shared->mapGlobalSymbolToGLobalID.TryGetValue(inst, globalID)) - return globalID; + auto canonical = type->GetCanonicalType(); + UInt id = 0; + if(context->shared->mapTypeToID.TryGetValue(canonical, id)) + { + return context->shared->bcTypes[id]; + } - SLANG_UNEXPECTED("no such ID"); + BytecodeGenerationPtr bcType = emitBCTypeImpl(context, canonical); + return bcType; } -uint32_t getTypeForGlobalSymbol( +uint32_t getTypeID( BytecodeGenerationContext* context, - IRInst* inst) + Type* type) +{ + // We have a type, and we need to emit it (if we haven't + // already) and return its index in the global type table. + BytecodeGenerationPtr bcType = emitBCType(context, type); + return bcType->id; +} + +uint32_t getTypeIDForGlobalSymbol( + BytecodeGenerationContext* context, + IRValue* inst) { auto type = inst->getType(); if(!type) return 0; - return getIDForGlobalSymbol(context, type); + return getTypeID(context, type); } BytecodeGenerationPtr allocateString( @@ -442,7 +644,7 @@ BytecodeGenerationPtr allocateString( BytecodeGenerationPtr tryGenerateNameForSymbol( BytecodeGenerationContext* context, - IRInst* inst) + IRGlobalValue* inst) { // TODO: this is gross, and the IR should probably have // a more direct means of querying a name for a symbol. @@ -462,9 +664,10 @@ BytecodeGenerationPtr tryGenerateNameForSymbol( return BytecodeGenerationPtr(); } +// Generate a `BCSymbol` that can represent a global value. BytecodeGenerationPtr generateBytecodeSymbolForInst( BytecodeGenerationContext* context, - IRInst* inst) + IRGlobalValue* inst) { switch( inst->op ) { @@ -474,7 +677,7 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( BytecodeGenerationPtr bcFunc = allocate(context); bcFunc->op = inst->op; - bcFunc->typeGlobalID = getTypeForGlobalSymbol(context, inst); + bcFunc->typeID = getTypeIDForGlobalSymbol(context, inst); BytecodeGenerationContext subContextStorage; BytecodeGenerationContext* subContext = &subContextStorage; @@ -515,7 +718,17 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( { UInt blockID = blockCounter++; UInt paramCount = 0; - for( auto ii = bb->firstChild; ii; ii = ii->nextInst ) + + for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() ) + { + // A parameter always uses a register. + regCounter++; + // + // We also want to keep a count of the parameters themselves. + paramCount++; + } + + for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst ) { switch( ii->op ) { @@ -528,15 +741,6 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( } break; - case kIROp_Param: - // A parameter always uses a register. - regCounter++; - // - // We also want to keep a count of the parameters themselves. - paramCount++; - break; - - case kIROp_Var: // A `var` (`alloca`) node needs two registers: // one to hold the actual storage, and another @@ -573,30 +777,25 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( // are always the first N registers in the overall list. // bcBlocks[blockID].params = bcRegs + regCounter; - for( auto ii = bb->firstChild; ii; ii = ii->nextInst ) + for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() ) { - if(ii->op != kIROp_Param) - continue; - Int localID = regCounter++; - subContext->mapInstToLocalID.Add(ii, localID); + subContext->mapInstToLocalID.Add(pp, localID); - bcRegs[localID].op = ii->op; - bcRegs[localID].name = tryGenerateNameForSymbol(context, ii); + bcRegs[localID].op = pp->op; +#if 0 + bcRegs[localID].name = tryGenerateNameForSymbol(context, pp); +#endif bcRegs[localID].previousVarIndexPlusOne = localID; - bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii); + bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, pp); } // Now loop over the non-parameter instructions and // allocate actual register locations to them. - for( auto ii = bb->firstChild; ii; ii = ii->nextInst ) + for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst ) { switch(ii->op) { - case kIROp_Param: - // Already handled. - break; - default: // For an ordinary instruction with a result, // allocate it here. @@ -606,9 +805,11 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( subContext->mapInstToLocalID.Add(ii, localID); bcRegs[localID].op = ii->op; +#if 0 bcRegs[localID].name = tryGenerateNameForSymbol(context, ii); +#endif bcRegs[localID].previousVarIndexPlusOne = localID; - bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii); + bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii); } break; @@ -624,30 +825,39 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( subContext->mapInstToLocalID.Add(ii, localID); bcRegs[localID].op = ii->op; +#if 0 bcRegs[localID].name = tryGenerateNameForSymbol(context, ii); +#endif bcRegs[localID].previousVarIndexPlusOne = localID; - bcRegs[localID].typeGlobalID = getTypeForGlobalSymbol(context, ii); + bcRegs[localID].typeID = getTypeIDForGlobalSymbol(context, ii); bcRegs[localID+1].op = ii->op; bcRegs[localID+1].previousVarIndexPlusOne = localID+1; - bcRegs[localID+1].typeGlobalID = getIDForGlobalSymbol(context, - ((IRPtrType*) ii->getType())->getValueType()); + bcRegs[localID+1].typeID = getTypeID(context, + (ii->getType()->As())->getValueType()); } break; } } } + assert(regCounter == regCount); // Now that we've allocated our blocks and our registers // we can go through the actual process of emitting instructions. Hooray! blockCounter = 0; + + // Offset of each basic block from the start of the code + // for the current funciton. + List blockOffsets; for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() ) { UInt blockID = blockCounter++; - bcBlocks[blockID].code = getPtr(context); + // Get local bytecode offset for current block. + UInt blockOffset = subContext->currentBytecode.Count(); + blockOffsets.Add( blockOffset ); - for( auto ii = bb->firstChild; ii; ii = ii->nextInst ) + for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst ) { // What we do with each instruction depends a bit on the // kind of instruction it is. @@ -671,6 +881,22 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( } } + + // We've collected bytecode for the instruction stream + // into a sub-context, so we can now append that code. + UInt byteCount = subContext->currentBytecode.Count(); + BytecodeGenerationPtr bytes = allocateArray(context, byteCount); + memcpy(&bytes[0], subContext->currentBytecode.Buffer(), byteCount); + + // Now that we've allocated the storage, we can write + // the bytecode pointers into the blocks. + blockCounter = 0; + for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + UInt blockID = blockCounter++; + bcBlocks[blockID].code = bytes + blockOffsets[blockID]; + } + // Finally, after emitting all the instructions we can // build a table of global symbols taht need to be // imported into the current function as constants. @@ -689,6 +915,17 @@ BytecodeGenerationPtr generateBytecodeSymbolForInst( } break; + case kIROp_global_var: + { + auto bcVar = allocate(context); + + bcVar->op = inst->op; + bcVar->typeID = getTypeID(context, inst->type); + + return bcVar; + } + break; + default: // Most instructions don't need a custom representation. return BytecodeGenerationPtr(); @@ -699,126 +936,98 @@ BytecodeGenerationPtr generateBytecodeForModule( BytecodeGenerationContext* context, IRModule* irModule) { - // The module will get encoded much like a function, - // and then that function will be "invoked" to load - // the module. + // A module in the bytecode is mostly just a list of the + // global symbols in the module. // - auto bcModule = allocate(context); - bcModule->op = irModule->op; - bcModule->typeGlobalID = 0; - - // The logical function that the module representats - // will only have a single block, so we can allocate it now. + // TODO: we need to be careful and recognize the distinction + // between the global symbols in the *AST* module, vs. those + // symbols which are effectively global in the *IR* module. // - auto bcBlock = allocate(context); - bcBlock->paramCount = 0; - bcBlock->params = BytecodeGenerationPtr(); - - bcModule->blockCount = 1; - bcModule->blocks = bcBlock; - - // Because the module is the top-most level, there is - // no need for it to have "constants" that represent - // values imported from the next outer scope. + // We probably need to store these distinctly, since we + // need the AST global symbols for reflection, and then + // also to reconstruct the AST on load when importing a + // serialized module. We then need the global IR symbols + // to use when linking, to quickly resolve things without + // needing any semantic knowledge of nesting at the AST level. // - bcModule->constCount = 0; - bcModule->consts = BytecodeGenerationPtr(); + auto bcModule = allocate(context); // We need to compute how many "registers" to allocate // for the module, where the registers represent the // values being computed at the global scope. - UInt regCount = 0; - for( auto inst = irModule->firstChild; inst; inst = inst->nextInst ) + UInt symbolCount = 0; + for( auto gv : irModule->globalValues ) { - if(!opHasResult(inst)) - continue; - - Int globalID = Int(regCount++); + Int globalID = Int(symbolCount++); - context->shared->mapGlobalSymbolToGLobalID.Add(inst, globalID); + // Ensure that local code inside functions can see these symbols + BCConst bcConst; + bcConst.flavor = kBCConstFlavor_GlobalSymbol; + bcConst.id = globalID; + context->shared->mapValueToGlobal.Add(gv, bcConst); // In the global scope, global IDs are also the local IDs - context->mapInstToLocalID.Add(inst, globalID); + context->mapInstToLocalID.Add(gv, globalID); } - auto bcRegs = allocateArray(context, regCount); + auto bcSymbols = allocateArray>(context, symbolCount); - bcModule->regCount = regCount; - bcModule->regs = bcRegs; + bcModule->symbolCount = symbolCount; + bcModule->symbols = bcSymbols; - // Now lets walk through and initialize all those bytecode - // register representations so that we can use them. - UInt regCounter= 0; - for( auto inst = irModule->firstChild; inst; inst = inst->nextInst ) + for( auto gv : irModule->globalValues ) { - if(!opHasResult(inst)) - continue; - - UInt regIndex = *context->mapInstToLocalID.TryGetValue(inst); + UInt symbolIndex = *context->mapInstToLocalID.TryGetValue(gv); - BytecodeGenerationPtr name = tryGenerateNameForSymbol(context, inst); - - bcRegs[regIndex].op = inst->op; - bcRegs[regIndex].name = name; - bcRegs[regIndex].typeGlobalID = getTypeForGlobalSymbol(context, inst); - bcRegs[regIndex].previousVarIndexPlusOne = regIndex; - } - - // Some instructions represent "nested" symbols that will need - // custom handling, and we will represent those here. - List> nestedSymbols; - for( auto inst = irModule->firstChild; inst; inst = inst->nextInst ) - { - UInt regIndex = *context->mapInstToLocalID.TryGetValue(inst); - - auto bcSymbol = generateBytecodeSymbolForInst(context, inst); + auto bcSymbol = generateBytecodeSymbolForInst(context, gv); if (!bcSymbol.getPtr()) continue; - UInt nestedSymbolID = nestedSymbols.Count(); - nestedSymbols.Add(bcSymbol); - - context->mapInstToNestedID.Add(inst, nestedSymbolID); + auto name = tryGenerateNameForSymbol(context, gv); + bcSymbol->name = name; - bcSymbol->name = bcRegs[regIndex].name; + bcSymbols[symbolIndex] = bcSymbol; } - auto nestedSymbolCount = nestedSymbols.Count(); - auto bcNestedSymbols = allocateArray>(context, nestedSymbolCount); + // At this point we should have identified all the literals we need: + UInt constantCount = context->shared->constants.Count(); + auto bcConstants = allocateArray(context, constantCount); + bcModule->constantCount = constantCount; + bcModule->constants = bcConstants; - bcModule->nestedSymbolCount = nestedSymbolCount; - bcModule->nestedSymbols = bcNestedSymbols; - for (UInt ii = 0; ii < nestedSymbolCount; ++ii) + for(UInt cc = 0; cc < constantCount; ++cc) { - bcNestedSymbols[ii] = nestedSymbols[ii]; - } - + auto irConstant = (IRConstant*) context->shared->constants[cc]; + bcConstants[cc].op = irConstant->op; + bcConstants[cc].typeID = getTypeID(context, irConstant->type); - // Finally, we can go through and emit the actual code for - // the initialization step. - bcBlock->code = getPtr(context); - for( auto inst = irModule->firstChild; inst; inst = inst->nextInst ) - { - // Generate bytecode for global-scope inst - generateBytecodeForInst(context, inst); - } - // Need to encode a terminator here, just to keep the encoding valid - encodeUInt(context, kIROp_ReturnVoid); + switch(irConstant->op) + { + case kIROp_IntLit: + { + auto ptr = allocate(context); + *ptr = irConstant->u.intVal; + bcConstants[cc].ptr = ptr.bitCast(); + } + break; -#if 0 + default: + break; + } - // Now we can go through and generate the bytecode object - // that will represent each of these global symbols + } - List> globalSymbols; + // At this point we should have collected all the types we need: + UInt typeCount = context->shared->bcTypes.Count(); + auto bcTypes = allocateArray>(context, typeCount); + bcModule->typeCount = typeCount; + bcModule->types = bcTypes; - for( auto inst = irModule->firstChild; inst; inst = inst->nextInst ) + for(UInt tt = 0; tt < typeCount; ++tt) { - // Generate bytecode for global-scope inst - auto globalSymbol = generateBytecodeForGlobalSymbol(context, inst); - globalSymbols.Add(globalSymbol); + bcTypes[tt] = context->shared->bcTypes[tt]; } -#endif + return bcModule; } @@ -833,10 +1042,6 @@ void generateBytecodeStream( memcpy(header->magic, "slang\0bc", sizeof(header->magic)); header->version = 0; - // HACK: ensure that a NULL pointer in an operand field can - // be encoded. - context->shared->mapGlobalSymbolToGLobalID.Add(nullptr, -1); - header->module = generateBytecodeForModule(context, irModule); } diff --git a/source/slang/bytecode.h b/source/slang/bytecode.h index e073763cf..1ea16406f 100644 --- a/source/slang/bytecode.h +++ b/source/slang/bytecode.h @@ -65,6 +65,7 @@ struct BCPtr } operator T*() const { return getPtr(); } + T* operator->() const { return getPtr(); } T* getPtr() const { @@ -73,9 +74,52 @@ struct BCPtr } }; -struct BCType +// Representation of a "type-level" value in +// the bytecode fiel. This corresponds to +// the AST-level notion of a `Val` +struct BCVal { + // The opcode used to define this value uint32_t op; + + // The ID of the type within its module + uint32_t id; +}; + +struct BCType : BCVal +{ + // TODO: avoid having to encode this? + uint32_t argCount; + + // type-specific operands follow + + // + + BCPtr* getArgs() { return (BCPtr*) (this +1); } + + BCVal* getArg(UInt index) { return getArgs()[index]; } +}; + +struct BCPtrType : BCType +{ + BCPtr valueType; +}; + +struct BCFuncType : BCType +{ + BCPtr resultType; + BCPtr paramTypes[1]; + + BCType* getResultType() { return resultType; } + + UInt getParamCount() { return argCount - 1; } + BCType* getParamType(UInt index) { return paramTypes[index]; } +}; + +struct BCConstant : BCVal +{ + uint32_t typeID; + BCPtr ptr; }; struct BCSymbol @@ -84,15 +128,9 @@ struct BCSymbol // this symbol; used to categorize things uint32_t op; - // The type of the symbol is represent - // as an index into the global-scope symbol - // list of the module. - // - // Note: This currently precludes having - // a register with a type that is not - // statically determined, but that is - // probably okay. - uint32_t typeGlobalID; + // The index (in the module's type table) + // of the type of the symbol: + uint32_t typeID; // The name of this symbol (which might // be a mangled name at some point, @@ -111,17 +149,18 @@ struct BCReg : BCSymbol uint32_t previousVarIndexPlusOne; }; +enum BCConstFlavor +{ + kBCConstFlavor_GlobalSymbol, + kBCConstFlavor_Constant, +}; + struct BCConst { - // The ID of the symbol in the global - // scope that we are trying to refer - // to. - // - // TODO: eventually, if we have general - // nesting, then this might be the - // entry in the outer scope that - // is being referenced. - uint32_t globalID; + // The flavor of bytecode constant we + // are dealing with. + uint32_t flavor; + uint32_t id; }; struct BCBlock @@ -153,16 +192,27 @@ struct BCFunc : BCSymbol // but this would make the encoding less dense. uint32_t constCount; BCPtr consts; - - // Data for "nested" symbols (e.g., a function - // nested inside this function). - uint32_t nestedSymbolCount; - BCPtr> nestedSymbols; }; -// A module is encoded more or less like a function. -struct BCModule : BCFunc +struct BCModule { + // The symbols (functions, global variables, etc.) + // that have been declared in the module. + uint32_t symbolCount; + BCPtr> symbols; + + // The types that are used by this module, stored + // in a single array so that they can be conveniently + // mapped to another representation in one go. + // + // Instructions in a bytecode instruction sequence + // might reference these types by index. + uint32_t typeCount; + BCPtr> types; + + // True compile-time constants go here: + uint32_t constantCount; + BCPtr constants; }; struct BCHeader diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 5c1f7380c..73d464d95 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -4546,26 +4546,21 @@ namespace Slang // if this is still an invoke expression, test arguments passed to inout/out parameter are LValues if(auto funcType = invoke->FunctionExpr->type->As()) { - List> paramsStorage; - List> * params = nullptr; - if (auto func = funcType->declRef.getDecl()) + UInt paramCount = funcType->getParamCount(); + for (UInt pp = 0; pp < paramCount; ++pp) { - paramsStorage = func->GetParameters().ToArray(); - params = ¶msStorage; - } - if (params) - { - for (UInt i = 0; i < (*params).Count(); i++) + auto paramType = funcType->getParamType(pp); + if (auto outParamType = paramType->As()) { - if ((*params)[i]->HasModifier()) + if (pp < expr->Arguments.Count() + && !expr->Arguments[pp]->type.IsLeftValue) { - if (i < expr->Arguments.Count() && expr->Arguments[i]->type->AsBasicType() && - !expr->Arguments[i]->type.IsLeftValue) + if (!isRewriteMode()) { - if (!isRewriteMode()) - { - getSink()->diagnose(expr->Arguments[i], Diagnostics::argumentExpectedLValue, (*params)[i]->getName()); - } + getSink()->diagnose( + expr->Arguments[pp], + Diagnostics::argumentExpectedLValue, + pp); } } } diff --git a/source/slang/compiler.h b/source/slang/compiler.h index 7ec812d63..1919699a1 100644 --- a/source/slang/compiler.h +++ b/source/slang/compiler.h @@ -15,6 +15,7 @@ namespace Slang struct IncludeHandler; class CompileRequest; class ProgramLayout; + class PtrType; enum class CompilerMode { @@ -369,6 +370,7 @@ namespace Slang RefPtr errorType; RefPtr initializerListType; RefPtr overloadedType; + RefPtr irBasicBlockType; Dictionary> builtinTypes; Dictionary magicDecls; @@ -388,6 +390,12 @@ namespace Slang Type* getOverloadedType(); Type* getErrorType(); + // Should not be used in front-end code + Type* getIRBasicBlockType(); + + // Construct pointer types on-demand + RefPtr getPtrType(RefPtr valueType); + SyntaxClass findSyntaxClass(Name* name); Dictionary > mapNameToSyntaxClass; diff --git a/source/slang/core.meta.slang b/source/slang/core.meta.slang index 2e200f63a..3755c51c7 100644 --- a/source/slang/core.meta.slang +++ b/source/slang/core.meta.slang @@ -85,6 +85,17 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) sb << "};\n"; } +// Declare built-in pointer type +// (eventually we can have the traditional syntax sugar for this) + +}}}} + +__generic +__magic_type(PtrType) +struct Ptr +{}; + +${{{{ // Declare vector and matrix types diff --git a/source/slang/core.meta.slang.h b/source/slang/core.meta.slang.h index cf2052d3c..864196944 100644 --- a/source/slang/core.meta.slang.h +++ b/source/slang/core.meta.slang.h @@ -86,6 +86,18 @@ for (int tt = 0; tt < kBaseTypeCount; ++tt) sb << "};\n"; } +// Declare built-in pointer type +// (eventually we can have the traditional syntax sugar for this) + +sb << "\n"; +sb << "\n"; +sb << "__generic\n"; +sb << "__magic_type(PtrType)\n"; +sb << "struct Ptr\n"; +sb << "{};\n"; +sb << "\n"; +sb << ""; + // Declare vector and matrix types diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index d4c1be706..64ac85bd4 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -93,6 +93,10 @@ struct SharedEmitContext bool needHackSamplerForTexelFetch = false; ExtensionUsageTracker extensionUsageTracker; + + Dictionary mapIRValueToID; + + HashSet irDeclsVisited; }; struct EmitContext @@ -1025,6 +1029,9 @@ struct EmitVisitor UNEXPECTED(GenericDeclRefType); UNEXPECTED(InitializerListType); + UNEXPECTED(IRBasicBlockType); + UNEXPECTED(PtrType); + #undef UNEXPECTED void visitNamedExpressionType(NamedExpressionType* type, TypeEmitArg const& arg) @@ -1231,6 +1238,17 @@ struct EmitVisitor EmitType(type, SourceLoc(), name, SourceLoc()); } + void EmitType(RefPtr type, String const& name) + { + // HACK: the rest of the code wants a `Name`, + // so we'll create one for a bit... + Name tempName; + tempName.text = name; + + EmitType(type, SourceLoc(), &tempName, SourceLoc()); + } + + void EmitType(RefPtr type) { emitTypeImpl(type, nullptr); @@ -3942,8 +3960,34 @@ emitDeclImpl(decl, nullptr); // IR-level emit logc - String getName(IRInst* inst) + UInt getID(IRValue* value) + { + auto& mapIRValueToID = context->shared->mapIRValueToID; + + UInt id = 0; + if (mapIRValueToID.TryGetValue(value, id)) + return id; + + id = mapIRValueToID.Count() + 1; + mapIRValueToID.Add(value, id); + return id; + } + + String getName(IRValue* inst) { + switch(inst->op) + { + case kIROp_decl_ref: + { + auto irDeclRef = (IRDeclRef*) inst; + return getText(irDeclRef->declRef.GetName()); + } + break; + + default: + break; + } + if(auto decoration = inst->findDecoration()) { auto decl = decoration->decl; @@ -3957,7 +4001,7 @@ emitDeclImpl(decl, nullptr); StringBuilder sb; sb << "_S"; - sb << inst->id; + sb << getID(inst); return sb.ProduceString(); } @@ -4035,7 +4079,7 @@ emitDeclImpl(decl, nullptr); } - +#if 0 void emitIRSimpleType( EmitContext* context, IRType* type) @@ -4153,12 +4197,14 @@ emitDeclImpl(decl, nullptr); } } +#endif CodeGenTarget getTarget(EmitContext* context) { return context->shared->target; } +#if 0 void emitGLSLTypePrefix( EmitContext* context, IRType* type) @@ -4194,44 +4240,9 @@ emitDeclImpl(decl, nullptr); break; } } +#endif - void emitIRVectorType( - EmitContext* context, - IRVectorType* type) - { - switch(getTarget(context)) - { - case CodeGenTarget::GLSL: - // HLSL style: `vec` - // e.g., `ivec4` - // - emitGLSLTypePrefix(context, type->getElementType()); - emit("vec"); - emitIRSimpleValue(context, type->getElementCount()); - break; - - default: - // HLSL style: `` - // e.g., `int4` - // - emitIRSimpleType(context, type->getElementType()); - emitIRSimpleValue(context, type->getElementCount()); - break; - } - } - - void emitIRMatrixType( - EmitContext* context, - IRMatrixType* type) - { - // TODO: this is a GLSL-vs-HLSL decision point - - emitIRSimpleType(context, type->getElementType()); - emitIRSimpleValue(context, type->getRowCount()); - emit("x"); - emitIRSimpleValue(context, type->getColumnCount()); - } - +#if 0 void emitIRType( EmitContext* context, IRType* type, @@ -4287,10 +4298,11 @@ emitDeclImpl(decl, nullptr); { emitIRType(context, type, (IRDeclaratorInfo*) nullptr); } +#endif bool shouldFoldIRInstIntoUseSites( EmitContext* context, - IRInst* inst) + IRValue* inst) { // Certain opcodes should always be folded in switch( inst->op ) @@ -4306,27 +4318,24 @@ emitDeclImpl(decl, nullptr); return true; } - // Certain *types* will usually want to be folded in + // Certain *types* will usually want to be folded in, + // because they aren't allowed as types for temporary + // variables. auto type = inst->getType(); - switch (type->op) + if(type->As()) { - case kIROp_ConstantBufferType: - case kIROp_TextureBufferType: // TODO: we need to be careful here, because // HLSL shader model 6 allows these as explicit // types. return true; - - case kIROp_TextureType: + } + else if(type->As()) + { // GLSL doesn't allow texture/resource types to // be used as first-class values, so we need // to fold them into their use sites in all cases if(getTarget(context) == CodeGenTarget::GLSL) return true; - break; - - default: - break; } // By default we will *not* fold things into their use sites. @@ -4335,20 +4344,16 @@ emitDeclImpl(decl, nullptr); bool isDerefBaseImplicit( EmitContext* context, - IRInst* inst) + IRValue* inst) { auto type = inst->getType(); - switch (type->op) + + if(type->As()) { - case kIROp_ConstantBufferType: - case kIROp_TextureBufferType: // TODO: we need to be careful here, because // HLSL shader model 6 allows these as explicit // types. return true; - - default: - break; } return false; @@ -4358,7 +4363,7 @@ emitDeclImpl(decl, nullptr); void emitIROperand( EmitContext* context, - IRInst* inst) + IRValue* inst) { if( shouldFoldIRInstIntoUseSites(context, inst) ) { @@ -4380,8 +4385,8 @@ emitDeclImpl(decl, nullptr); EmitContext* context, IRInst* inst) { - UInt argCount = inst->argCount - 1; - IRUse* args = inst->getArgs() + 1; + UInt argCount = inst->argCount; + IRUse* args = inst->getArgs(); emit("("); for(UInt aa = 0; aa < argCount; ++aa) @@ -4392,12 +4397,35 @@ emitDeclImpl(decl, nullptr); emit(")"); } + void emitIRType( + EmitContext* context, + IRType* type, + String const& name) + { + EmitType(type, name); + } + + void emitIRType( + EmitContext* context, + IRType* type, + Name* name) + { + EmitType(type, name); + } + + void emitIRType( + EmitContext* context, + IRType* type) + { + EmitType(type); + } + void emitIRInstResultDecl( EmitContext* context, IRInst* inst) { auto type = inst->getType(); - if(!type || type->op == kIROp_VoidType) + if(!type) return; emitIRType(context, type, getName(inst)); @@ -4406,9 +4434,10 @@ emitDeclImpl(decl, nullptr); void emitIRInstExpr( EmitContext* context, - IRInst* inst) + IRValue* value) { - switch(inst->op) + IRInst* inst = (IRInst*) value; + switch(value->op) { case kIROp_IntLit: case kIROp_FloatLit: @@ -4418,13 +4447,13 @@ emitDeclImpl(decl, nullptr); case kIROp_Construct: // Simple constructor call - if( inst->getArgCount() == 2 && getTarget(context) == CodeGenTarget::HLSL) + if( inst->getArgCount() == 1 && getTarget(context) == CodeGenTarget::HLSL) { // Need to emit as cast for HLSL emit("("); emitIRType(context, inst->getType()); emit(") "); - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); } else { @@ -4446,7 +4475,7 @@ emitDeclImpl(decl, nullptr); emitIRType(context, inst->getType()); } emit("("); - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit(")"); break; @@ -4480,9 +4509,9 @@ emitDeclImpl(decl, nullptr); #define CASE(OPCODE, OP) \ case OPCODE: \ - emitIROperand(context, inst->getArg(1)); \ + emitIROperand(context, inst->getArg(0)); \ emit("" #OP " "); \ - emitIROperand(context, inst->getArg(2)); \ + emitIROperand(context, inst->getArg(1)); \ break CASE(kIROp_Add, +); @@ -4512,7 +4541,7 @@ emitDeclImpl(decl, nullptr); case kIROp_Not: { - if (inst->getType()->op == kIROp_BoolType) + if (inst->getType()->Equals(getSession()->getBoolType())) { emit("!"); } @@ -4520,73 +4549,72 @@ emitDeclImpl(decl, nullptr); { emit("~"); } - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); } break; case kIROp_Sample: - // argument 0 is the instruction's type - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit(".Sample("); - emitIROperand(context, inst->getArg(2)); + emitIROperand(context, inst->getArg(1)); emit(", "); - emitIROperand(context, inst->getArg(3)); + emitIROperand(context, inst->getArg(2)); emit(")"); break; case kIROp_SampleGrad: // argument 0 is the instruction's type - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit(".SampleGrad("); + emitIROperand(context, inst->getArg(1)); + emit(", "); emitIROperand(context, inst->getArg(2)); emit(", "); emitIROperand(context, inst->getArg(3)); emit(", "); emitIROperand(context, inst->getArg(4)); - emit(", "); - emitIROperand(context, inst->getArg(5)); emit(")"); break; case kIROp_Load: // TODO: this logic will really only work for a simple variable reference... - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); break; case kIROp_Store: // TODO: this logic will really only work for a simple variable reference... - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit(" = "); - emitIROperand(context, inst->getArg(2)); + emitIROperand(context, inst->getArg(1)); break; case kIROp_Call: { - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit("("); - UInt argCount = inst->getArgCount() - 2; + UInt argCount = inst->getArgCount() - 1; for( UInt aa = 0; aa < argCount; ++aa ) { if(aa != 0) emit(", "); - emitIROperand(context, inst->getArg(aa + 2)); + emitIROperand(context, inst->getArg(aa + 1)); } emit(")"); } break; case kIROp_BufferLoad: - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit("["); - emitIROperand(context, inst->getArg(2)); + emitIROperand(context, inst->getArg(1)); emit("]"); break; case kIROp_BufferStore: - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit("["); - emitIROperand(context, inst->getArg(2)); + emitIROperand(context, inst->getArg(1)); emit("] = "); - emitIROperand(context, inst->getArg(3)); + emitIROperand(context, inst->getArg(2)); break; case kIROp_GroupMemoryBarrierWithGroupSync: @@ -4595,9 +4623,9 @@ emitDeclImpl(decl, nullptr); case kIROp_getElement: case kIROp_getElementPtr: - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit("["); - emitIROperand(context, inst->getArg(2)); + emitIROperand(context, inst->getArg(1)); emit("]"); break; @@ -4605,9 +4633,9 @@ emitDeclImpl(decl, nullptr); case kIROp_Mul_Matrix_Vector: case kIROp_Mul_Matrix_Matrix: emit("mul("); - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit(", "); - emitIROperand(context, inst->getArg(2)); + emitIROperand(context, inst->getArg(1)); emit(")"); break; @@ -4619,7 +4647,7 @@ emitDeclImpl(decl, nullptr); UInt elementCount = ii->getElementCount(); for (UInt ee = 0; ee < elementCount; ++ee) { - IRInst* irElementIndex = ii->getElementIndex(ee); + IRValue* irElementIndex = ii->getElementIndex(ee); assert(irElementIndex->op == kIROp_IntLit); IRConstant* irConst = (IRConstant*)irElementIndex; @@ -4658,7 +4686,7 @@ emitDeclImpl(decl, nullptr); case kIROp_Var: { auto ptrType = inst->getType(); - auto valType = ((IRPtrType*)ptrType)->getValueType(); + auto valType = ((PtrType*)ptrType)->getValueType(); auto name = getName(inst); emitIRType(context, valType, name); @@ -4689,14 +4717,14 @@ emitDeclImpl(decl, nullptr); { auto ii = (IRSwizzleSet*)inst; emitIRInstResultDecl(context, inst); - emitIROperand(context, inst->getArg(1)); + emitIROperand(context, inst->getArg(0)); emit(";\n"); emitIROperand(context, inst); emit("."); UInt elementCount = ii->getElementCount(); for (UInt ee = 0; ee < elementCount; ++ee) { - IRInst* irElementIndex = ii->getElementIndex(ee); + IRValue* irElementIndex = ii->getElementIndex(ee); assert(irElementIndex->op == kIROp_IntLit); IRConstant* irConst = (IRConstant*)irElementIndex; @@ -4707,34 +4735,16 @@ emitDeclImpl(decl, nullptr); emit(kComponents[elementIndex]); } emit(" = "); - emitIROperand(context, inst->getArg(2)); + emitIROperand(context, inst->getArg(1)); emit(";\n"); } break; - -#if 0 - case kIROp_unconditionalBranch: - emit("// unconditionalBranch "); - emitIROperand(context, inst->getArg(1)); - emit("\n"); - break; - - case kIROp_conditionalBranch: - emit("// conditionalBranch "); - emitIROperand(context, inst->getArg(1)); - emit(" "); - emitIROperand(context, inst->getArg(2)); - emit(" "); - emitIROperand(context, inst->getArg(3)); - emit("\n"); - break; -#endif } } void emitIRSemantics( EmitContext* context, - IRInst* inst) + IRValue* inst) { auto decoration = inst->findDecoration(); if( decoration ) @@ -4745,7 +4755,7 @@ emitDeclImpl(decl, nullptr); VarLayout* getVarLayout( EmitContext* context, - IRInst* var) + IRValue* var) { auto decoration = var->findDecoration(); if (!decoration) @@ -4756,7 +4766,7 @@ emitDeclImpl(decl, nullptr); void emitIRLayoutSemantics( EmitContext* context, - IRInst* inst, + IRValue* inst, char const* uniformSemanticSpelling = "register") { auto layout = getVarLayout(context, inst); @@ -4784,9 +4794,9 @@ emitDeclImpl(decl, nullptr); while(block != end) { // Start by emitting the non-terminator instructions in the block. - auto terminator = block->lastChild; + auto terminator = block->getLastInst(); assert(isTerminatorInst(terminator)); - for (auto inst = block->firstChild; inst != terminator; inst = inst->nextInst) + for (auto inst = block->getFirstInst(); inst != terminator; inst = inst->nextInst) { emitIRInst(context, inst); } @@ -5095,10 +5105,8 @@ emitDeclImpl(decl, nullptr); // encoded as a parameter of pointer type, so // we need to decode that here. // - if( paramType->op == kIROp_PtrType ) + if( auto ptrType = paramType->As() ) { - auto ptrType = (IRPtrType*) paramType; - // TODO: we need a way to distinguish `out` // from `inout`. The easiest way to do // that might be to have each be a distinct @@ -5243,7 +5251,7 @@ emitDeclImpl(decl, nullptr); } } - +#if 0 void emitIRStruct( EmitContext* context, IRStructDecl* structType) @@ -5263,6 +5271,7 @@ emitDeclImpl(decl, nullptr); } emit("};\n"); } +#endif void emitIRVarModifiers( EmitContext* context, @@ -5317,9 +5326,9 @@ emitDeclImpl(decl, nullptr); } void emitIRParameterBlock( - EmitContext* context, - IRVar* varDecl, - IRUniformBufferType* type) + EmitContext* context, + IRGlobalVar* varDecl, + UniformParameterBlockType* type) { emit("cbuffer "); emit(getName(varDecl)); @@ -5341,17 +5350,15 @@ emitDeclImpl(decl, nullptr); typeLayout = parameterBlockTypeLayout->elementTypeLayout; } - switch( elementType->op ) + if(auto declRefType = elementType->As()) { - case kIROp_StructType: + if(auto structDeclRef = declRefType->declRef.As()) { - auto structType = (IRStructDecl*) elementType; - auto structTypeLayout = typeLayout.As(); assert(structTypeLayout); UInt fieldIndex = 0; - for(auto ff = structType->getFirstField(); ff; ff = ff->getNextField()) + for(auto ff : GetFields(structDeclRef)) { // TODO: need a plan to deal with the case where the IR-level // `struct` type might not match the high-level type, so that @@ -5367,19 +5374,18 @@ emitDeclImpl(decl, nullptr); emitIRVarModifiers(context, fieldLayout); - auto fieldType = ff->getFieldType(); - emitIRType(context, fieldType, getName(ff)); + auto fieldType = GetType(ff); + emitIRType(context, fieldType, ff.GetName()); emitHLSLParameterBlockFieldLayoutSemantics(layout, fieldLayout); emit(";\n"); } } - break; - - default: + } + else + { emit("/* unexpected */"); - break; } emit("}\n"); @@ -5391,8 +5397,9 @@ emitDeclImpl(decl, nullptr); { auto allocatedType = varDecl->getType(); auto varType = allocatedType->getValueType(); - auto addressSpace = allocatedType->getAddressSpace(); +// auto addressSpace = allocatedType->getAddressSpace(); +#if 0 switch( varType->op ) { case kIROp_ConstantBufferType: @@ -5403,6 +5410,7 @@ emitDeclImpl(decl, nullptr); default: break; } +#endif // Need to emit appropriate modifiers here. @@ -5410,6 +5418,7 @@ emitDeclImpl(decl, nullptr); emitIRVarModifiers(context, layout); +#if 0 switch (addressSpace) { default: @@ -5419,6 +5428,51 @@ emitDeclImpl(decl, nullptr); emit("groupshared "); break; } +#endif + + emitIRType(context, varType, getName(varDecl)); + + emitIRSemantics(context, varDecl); + + emitIRLayoutSemantics(context, varDecl); + + emit(";\n"); + } + + void emitIRGlobalVar( + EmitContext* context, + IRGlobalVar* varDecl) + { + auto allocatedType = varDecl->getType(); + auto varType = allocatedType->getValueType(); +// auto addressSpace = allocatedType->getAddressSpace(); + + if (auto paramBlockType = varType->As()) + { + emitIRParameterBlock( + context, + varDecl, + paramBlockType); + return; + } + + // Need to emit appropriate modifiers here. + + auto layout = getVarLayout(context, varDecl); + + emitIRVarModifiers(context, layout); + +#if 0 + switch (addressSpace) + { + default: + break; + + case kIRAddressSpace_GroupShared: + emit("groupshared "); + break; + } +#endif emitIRType(context, varType, getName(varDecl)); @@ -5431,7 +5485,7 @@ emitDeclImpl(decl, nullptr); void emitIRGlobalInst( EmitContext* context, - IRInst* inst) + IRGlobalValue* inst) { // TODO: need to be able to `switch` on the IR opcode here, // so there is some work to be done. @@ -5441,9 +5495,15 @@ emitDeclImpl(decl, nullptr); emitIRFunc(context, (IRFunc*) inst); break; + case kIROp_global_var: + emitIRGlobalVar(context, (IRGlobalVar*) inst); + break; + +#if 0 case kIROp_StructType: emitIRStruct(context, (IRStructDecl*) inst); break; +#endif case kIROp_Var: emitIRVar(context, (IRVar*) inst); @@ -5454,13 +5514,154 @@ emitDeclImpl(decl, nullptr); } } + void ensureStructDecl( + EmitContext* context, + DeclRef declRef) + { + // TODO: Eventually need to deal with the case where + // we have user-defined generic types. + // + auto decl = declRef.getDecl(); + + if(context->shared->irDeclsVisited.Contains(decl)) + return; + + context->shared->irDeclsVisited.Add(decl); + + // First emit any types used by fields of this type + for( auto ff : GetFields(declRef) ) + { + if(ff.getDecl()->HasModifier()) + continue; + + auto fieldType = GetType(ff); + emitIRUsedType(context, fieldType); + } + + Emit("struct "); + emit(declRef.GetName()); + Emit("\n{\n"); + for( auto ff : GetFields(declRef) ) + { + if(ff.getDecl()->HasModifier()) + continue; + + auto fieldType = GetType(ff); + emitIRType(context, fieldType, ff.GetName()); + + EmitSemantics(ff.getDecl()); + + emit(";\n"); + } + Emit("};\n"); + } + + // A type is going to be used by the IR, so + // make sure that we have emitted whatever + // it needs. + void emitIRUsedType( + EmitContext* context, + Type* type) + { + if(type->As()) + {} + else if(type->As()) + {} + else if(type->As()) + {} + else if(auto arrayType = type->As()) + { + emitIRUsedType(context, arrayType->baseType); + } + else if( auto textureType = type->As() ) + { + emitIRUsedType(context, textureType->elementType); + } + else if( auto genericType = type->As() ) + { + emitIRUsedType(context, genericType->elementType); + } + else if( auto ptrType = type->As() ) + { + emitIRUsedType(context, ptrType->getValueType()); + } + else if(type->As() ) + { + } + else if( auto declRefType = type->As() ) + { + auto declRef = declRefType->declRef; + auto decl = declRef.getDecl(); + + if(decl->HasModifier() + || decl->HasModifier()) + { + return; + } + + if( auto structDeclRef = declRef.As() ) + { + // + ensureStructDecl(context, structDeclRef); + } + } + else + {} + } + + void emitIRUsedTypesForValue( + EmitContext* context, + IRValue* value) + { + if(!value) return; + switch( value->op ) + { + case kIROp_Func: + { + auto irFunc = (IRFunc*) value; + emitIRUsedType(context, irFunc->getResultType()); + for( auto bb = irFunc->getFirstBlock(); bb; bb = bb->getNextBlock() ) + { + for( auto pp = bb->getFirstParam(); pp; pp = pp->getNextParam() ) + { + emitIRUsedTypesForValue(context, pp); + } + + for( auto ii = bb->getFirstInst(); ii; ii = ii->nextInst ) + { + emitIRUsedTypesForValue(context, ii); + } + } + } + break; + + default: + { + emitIRUsedType(context, value->type); + } + break; + } + } + + void emitIRUsedTypesForModule( + EmitContext* context, + IRModule* module) + { + for (auto gv : module->globalValues) + { + emitIRUsedTypesForValue(context, gv); + } + } + void emitIRModule( EmitContext* context, IRModule* module) { - for(auto ii = module->firstChild; ii; ii = ii->nextInst ) + emitIRUsedTypesForModule(context, module); + + for (auto gv : module->globalValues) { - emitIRGlobalInst(context, ii); + emitIRGlobalInst(context, gv); } } diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index 60dc353ac..c11d66571 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -22,20 +22,20 @@ INST(arrayType, Array, 2, 0) INST(BoolType, Bool, 0, 0) -INST(Float16Type, Float16, 0, 0) -INST(Float32Type, Float32, 0, 0) -INST(Float64Type, Float64, 0, 0) +INST(Float16Type, Float16, 0, 0) +INST(Float32Type, Float32, 0, 0) +INST(Float64Type, Float64, 0, 0) // Signed integer types. // Note that `IntPtr` represents a pointer-sized integer type, // and will end up being equivalent to either `Int32` or `Int64` // when it comes time to actually generate code. // -INST(Int8Type, Int8, 0, 0) -INST(Int16Type, Int16, 0, 0) -INST(Int32Type, Int32, 0, 0) -INST(IntPtrType, IntPtr, 0, 0) -INST(Int64Type, Int64, 0, 0) +INST(Int8Type, Int8, 0, 0) +INST(Int16Type, Int16, 0, 0) +INST(Int32Type, Int32, 0, 0) +INST(IntPtrType, IntPtr, 0, 0) +INST(Int64Type, Int64, 0, 0) // Unlike a lot of other IRs, we retain a distinction between // signed and unsigned integer types, simply because many of @@ -52,11 +52,11 @@ INST(Int64Type, Int64, 0, 0) // or else we want to keep using the orignal types, but need // to cast around any ordinary math operations on signed types. // -INST(UInt8Type, Int8, 0, 0) -INST(UInt16Type, Int16, 0, 0) -INST(UInt32Type, Int32, 0, 0) -INST(UIntPtrType, IntPtr, 0, 0) -INST(UInt64Type, Int64, 0, 0) +INST(UInt8Type, Int8, 0, 0) +INST(UInt16Type, Int16, 0, 0) +INST(UInt32Type, Int32, 0, 0) +INST(UIntPtrType, IntPtr, 0, 0) +INST(UInt64Type, Int64, 0, 0) // A user-defined structure declaration at the IR level. // Unlike in the AST where there is a distinction between @@ -94,6 +94,7 @@ INST(GenericParameterType, GenericParameterType, 1, 0) INST(boolConst, boolConst, 0, 0) INST(IntLit, integer_constant, 0, 0) INST(FloatLit, float_constant, 0, 0) +INST(decl_ref, decl_ref, 0, 0) INST(Construct, construct, 0, 0) INST(Call, call, 1, 0) @@ -102,6 +103,8 @@ INST(Module, module, 0, PARENT) INST(Func, func, 0, PARENT) INST(Block, block, 0, PARENT) +INST(global_var, global_var, 0, 0) + INST(Param, param, 0, 0) INST(StructField, field, 0, 0) INST(Var, var, 0, 0) diff --git a/source/slang/ir-insts.h b/source/slang/ir-insts.h index 1e9638ade..74151adaa 100644 --- a/source/slang/ir-insts.h +++ b/source/slang/ir-insts.h @@ -10,6 +10,7 @@ #include "compiler.h" #include "ir.h" +#include "syntax.h" namespace Slang { @@ -62,65 +63,22 @@ struct IRLoopControlDecoration : IRDecoration IRLoopControl mode; }; -struct IRMangledNameDecoration : IRDecoration -{ - enum { kDecorationOp = kIRDecorationOp_MangledName }; - - String mangledName; -}; - -// Types - -struct IRVectorType : IRType -{ - IRUse elementType; - IRUse elementCount; - - IRType* getElementType() { return (IRType*) elementType.usedValue; } - IRInst* getElementCount() { return elementCount.usedValue; } -}; - -struct IRMatrixType : IRType -{ - IRUse elementType; - IRUse rowCount; - IRUse columnCount; - - IRType* getElementType() { return (IRType*) elementType.usedValue; } - IRInst* getRowCount() { return rowCount.usedValue; } - IRInst* getColumnCount() { return columnCount.usedValue; } -}; - -struct IRArrayType : IRType -{ - IRUse elementType; - IRUse elementCount; - - IRType* getElementType() { return (IRType*) elementType.usedValue; } - IRInst* getElementCount() { return elementCount.usedValue; } -}; +// -struct IRGenericParameterType : IRType +// An IR node to represent a reference to an AST-level +// declaration. +struct IRDeclRef : IRValue { - IRUse index; + DeclRefBase declRef; }; +// -struct IRFuncType : IRType -{ - IRUse resultType; - // parameter tyeps are varargs... - - IRType* getResultType() { return (IRType*) resultType.usedValue; } - UInt getParamCount() - { - return getArgCount() - 2; - } - IRType* getParamType(UInt index) - { - return (IRType*) getArg(2 + index); - } -}; +// TODO(tfoley): the IR-level representation of +// pointers had an "address space" field that the +// current AST level lacks. This capability was +// used to represent `groupshared` allocation, +// so we probably need to find an alternative. // Address spaces for IR pointers enum IRAddressSpace : UInt @@ -132,6 +90,7 @@ enum IRAddressSpace : UInt kIRAddressSpace_GroupShared, }; +#if 0 struct IRPtrType : IRType { IRUse valueType; @@ -145,31 +104,7 @@ struct IRPtrType : IRType ((IRConstant*)addressSpace.usedValue)->u.intVal); } }; - -struct IRTextureType : IRType -{ - IRUse flavor; - IRUse elementType; - - IRIntegerValue getFlavor() { return ((IRConstant*) flavor.usedValue)->u.intVal; } - IRType* getElementType() { return (IRType*) elementType.usedValue; } -}; - -struct IRBufferType : IRType -{ - IRUse elementType; - IRType* getElementType() { return (IRType*) elementType.usedValue; } -}; - - -struct IRUniformBufferType : IRType -{ - IRUse elementType; - IRType* getElementType() { return (IRType*) elementType.usedValue; } -}; - -struct IRConstantBufferType : IRUniformBufferType {}; -struct IRTextureBufferType : IRUniformBufferType {}; +#endif struct IRCall : IRInst { @@ -187,14 +122,13 @@ struct IRStore : IRInst IRUse val; }; -struct IRStructField; struct IRFieldExtract : IRInst { IRUse base; IRUse field; - IRInst* getBase() { return base.usedValue; } - IRStructField* getField() { return (IRStructField*) field.usedValue; } + IRValue* getBase() { return base.usedValue; } + IRValue* getField() { return field.usedValue; } }; struct IRFieldAddress : IRInst @@ -202,8 +136,8 @@ struct IRFieldAddress : IRInst IRUse base; IRUse field; - IRInst* getBase() { return base.usedValue; } - IRStructField* getField() { return (IRStructField*) field.usedValue; } + IRValue* getBase() { return base.usedValue; } + IRValue* getField() { return field.usedValue; } }; // Terminators @@ -215,7 +149,7 @@ struct IRReturnVal : IRReturn { IRUse val; - IRInst* getVal() { return val.usedValue; } + IRValue* getVal() { return val.usedValue; } }; struct IRReturnVoid : IRReturn @@ -260,7 +194,7 @@ struct IRConditionalBranch : IRTerminatorInst IRUse trueBlock; IRUse falseBlock; - IRInst* getCondition() { return condition.usedValue; } + IRValue* getCondition() { return condition.usedValue; } IRBlock* getTrueBlock() { return (IRBlock*)trueBlock.usedValue; } IRBlock* getFalseBlock() { return (IRBlock*)falseBlock.usedValue; } }; @@ -296,12 +230,12 @@ struct IRSwizzle : IRReturn { IRUse base; - IRInst* getBase() { return base.usedValue; } + IRValue* getBase() { return base.usedValue; } UInt getElementCount() { return getArgCount() - 2; } - IRInst* getElementIndex(UInt index) + IRValue* getElementIndex(UInt index) { return getArg(index + 2); } @@ -312,43 +246,43 @@ struct IRSwizzleSet : IRReturn IRUse base; IRUse source; - IRInst* getBase() { return base.usedValue; } - IRInst* getSource() { return source.usedValue; } + IRValue* getBase() { return base.usedValue; } + IRValue* getSource() { return source.usedValue; } UInt getElementCount() { return getArgCount() - 3; } - IRInst* getElementIndex(UInt index) + IRValue* getElementIndex(UInt index) { return getArg(index + 3); } }; -// "Parent" Instructions (Declarations) - -struct IRStructField : IRInst +// An IR `var` instruction conceptually represents +// a stack allocation of some memory. +struct IRVar : IRInst { - IRType* getFieldType() { return (IRType*) type.usedValue; } - - IRStructField* getNextField() { return (IRStructField*) nextInst; } + PtrType* getType() + { + return (PtrType*)type.Ptr(); + } }; -struct IRStructDecl : IRParentInst +struct IRGlobalVar : IRGlobalValue { - IRStructField* getFirstField() { return (IRStructField*) firstChild; } - IRStructField* getLastField() { return (IRStructField*) lastChild; } -}; + // TODO: should contain information + // for use in initializing the variable + // (e.g., a reference to a function + // that is to be evaluated to provide + // the initial value, or a basic block + // that defines a DAG of constant + // values to use as initial values...) - -struct IRVar : IRInst -{ - IRPtrType* getType() - { - return (IRPtrType*)type.usedValue; - } + PtrType* getType() { return type.As(); } }; + // Description of an instruction to be used for global value numbering struct IRInstKey { @@ -369,6 +303,13 @@ bool operator==(IRConstantKey const& left, IRConstantKey const& right); struct SharedIRBuilder { + // The parent compilation session + Session* session; + Session* getSession() + { + return session; + } + // The module that will own all of the IR IRModule* module; @@ -381,53 +322,36 @@ struct IRBuilder // Shared state for all IR builders working on the same module SharedIRBuilder* shared; - IRModule* getModule() { return shared->module; } - - // The parent instruction to add children to. - IRParentInst* parentInst; - - void addInst(IRParentInst* parent, IRInst* inst); - void addInst(IRInst* inst); - - IRType* getBaseType(BaseType flavor); - IRType* getBoolType(); - IRType* getVectorType(IRType* elementType, IRValue* elementCount); - IRType* getMatrixType( - IRType* elementType, - IRValue* rowCount, - IRValue* columnCount); - IRType* getArrayType(IRType* elementType, IRValue* elementCount); - IRType* getArrayType(IRType* elementType); - IRType* getGenericParameterType(UInt index); - - IRType* getTypeType(); - IRType* getVoidType(); - IRType* getBlockType(); - - IRType* getIntrinsicType( - IROp op, - UInt argCount, - IRValue* const* args); + Session* getSession() + { + return shared->getSession(); + } - IRStructDecl* createStructType(); - IRStructField* createStructField(IRType* fieldType); + IRModule* getModule() { return shared->module; } - IRType* getFuncType( - UInt paramCount, - IRType* const* paramTypes, - IRType* resultType); + // The current function and block being inserted into + // (or `null` if we aren't inserting). + IRFunc* func = nullptr; + IRBlock* block = nullptr; + // + // TODO: we eventually also want an `IRInst*` for + // an instruction to insert before, so that we + // can also use the builder to insert inside + // an existing block. - IRType* getPtrType( - IRType* valueType, - IRAddressSpace addressSpace); + IRFunc* getFunc() { return func; } + IRBlock* getBlock() { return block; } - IRType* getPtrType( - IRType* valueType); + void addInst(IRBlock* block, IRInst* inst); + void addInst(IRInst* inst); IRValue* getBoolValue(bool value); IRValue* getIntValue(IRType* type, IRIntegerValue value); IRValue* getFloatValue(IRType* type, IRFloatingPointValue value); + IRValue* getDeclRefVal( + DeclRefBase const& declRef); + IRInst* emitCallInst( IRType* type, IRValue* func, @@ -448,6 +372,8 @@ struct IRBuilder IRModule* createModule(); IRFunc* createFunc(); + IRGlobalVar* createGlobalVar( + IRType* valueType); IRBlock* createBlock(); IRBlock* emitBlock(); @@ -472,12 +398,12 @@ struct IRBuilder IRInst* emitFieldExtract( IRType* type, IRValue* base, - IRStructField* field); + IRValue* field); IRInst* emitFieldAddress( IRType* type, IRValue* basePtr, - IRStructField* field); + IRValue* field); IRInst* emitElementExtract( IRType* type, @@ -556,24 +482,24 @@ struct IRBuilder IRBlock* breakBlock); IRDecoration* addDecorationImpl( - IRInst* inst, + IRValue* value, UInt decorationSize, IRDecorationOp op); template - T* addDecoration(IRInst* inst, IRDecorationOp op) + T* addDecoration(IRValue* value, IRDecorationOp op) { - return (T*) addDecorationImpl(inst, sizeof(T), op); + return (T*) addDecorationImpl(value, sizeof(T), op); } template - T* addDecoration(IRInst* inst) + T* addDecoration(IRValue* value) { - return (T*) addDecorationImpl(inst, sizeof(T), IRDecorationOp(T::kDecorationOp)); + return (T*) addDecorationImpl(value, sizeof(T), IRDecorationOp(T::kDecorationOp)); } - IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRInst* inst, Decl* decl); - IRLayoutDecoration* addLayoutDecoration(IRInst* inst, Layout* layout); + IRHighLevelDeclDecoration* addHighLevelDeclDecoration(IRValue* value, Decl* decl); + IRLayoutDecoration* addLayoutDecoration(IRValue* value, Layout* layout); }; } diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 3e12ef1a2..ee28227e3 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -40,7 +40,7 @@ namespace Slang // - void IRUse::init(IRValue* u, IRValue* v) + void IRUse::init(IRInst* u, IRValue* v) { user = u; usedValue = v; @@ -58,10 +58,17 @@ namespace Slang IRUse* IRInst::getArgs() { - return &type; + // We assume that *all* instructions are laid out + // in memory such that their arguments come right + // after the first `sizeof(IRInst)` bytes. + // + // TODO: we probably need to be careful and make + // this more robust. + + return (IRUse*)(this + 1); } - IRDecoration* IRInst::findDecorationImpl(IRDecorationOp decorationOp) + IRDecoration* IRValue::findDecorationImpl(IRDecorationOp decorationOp) { for( auto dd = firstDecoration; dd; dd = dd->next ) { @@ -71,13 +78,28 @@ namespace Slang return nullptr; } + // IRBlock + + void IRBlock::addParam(IRParam* param) + { + if (auto lp = lastParam) + { + lp->nextParam = param; + param->prevParam = lp; + } + else + { + firstParam = param; + } + lastParam = param; + } + // IRFunc IRType* IRFunc::getResultType() { return getType()->getResultType(); } UInt IRFunc::getParamCount() { return getType()->getParamCount(); } IRType* IRFunc::getParamType(UInt index) { return getType()->getParamType(index); } - IRParam* IRFunc::getFirstParam() { auto entryBlock = getFirstBlock(); @@ -86,41 +108,20 @@ namespace Slang return entryBlock->getFirstParam(); } - // IRBlock - - IRParam* IRBlock::getFirstParam() - { - auto firstInst = firstChild; - if(!firstInst) return nullptr; - - if(firstInst->op != kIROp_Param) - return nullptr; - - return (IRParam*) firstInst; - } - - - // IRParam - - IRParam* IRParam::getNextParam() + void IRFunc::addBlock(IRBlock* block) { - // TODO: this is written as a search because we don't - // currently do the careful thing and emit parameters - // before any other members of a block. - // - // This should change on the emit side, instead. + block->parentFunc = this; - auto next = nextInst; - - for (;;) + if (auto lb = lastBlock) { - if (!next) return nullptr; - - if(next->op == kIROp_Param) - return (IRParam*) next; - - next = next->nextInst; + lb->nextBlock = block; + block->prevBlock = lb; + } + else + { + firstBlock = block; } + lastBlock = block; } // @@ -156,27 +157,27 @@ namespace Slang // // Add an instruction to a specific parent - void IRBuilder::addInst(IRParentInst* parent, IRInst* inst) + void IRBuilder::addInst(IRBlock* block, IRInst* inst) { - inst->parent = parent; + inst->parentBlock = block; - if (!parent->firstChild) + if (!block->firstInst) { inst->prevInst = nullptr; inst->nextInst = nullptr; - parent->firstChild = inst; - parent->lastChild = inst; + block->firstInst = inst; + block->lastInst = inst; } else { - auto prev = parent->lastChild; + auto prev = block->lastInst; inst->prevInst = prev; inst->nextInst = nullptr; prev->nextInst = inst; - parent->lastChild = inst; + block->lastInst = inst; } } @@ -184,19 +185,42 @@ namespace Slang void IRBuilder::addInst( IRInst* inst) { - auto parent = parentInst; + auto parent = block; if (!parent) return; addInst(parent, inst); } + static IRValue* createValueImpl( + IRBuilder* builder, + UInt size, + IROp op, + IRType* type) + { + IRValue* value = (IRValue*) malloc(size); + memset(value, 0, size); + value->op = op; + value->type = type; + return value; + } + + template + static T* createValue( + IRBuilder* builder, + IROp op, + IRType* type) + { + return (T*) createValueImpl(builder, sizeof(T), op, type); + } + + // Create an IR instruction/value and initialize it. // // In this case `argCount` and `args` represnt the // arguments *after* the type (which is a mandatory // argument for all instructions). - static IRValue* createInstImpl( + static IRInst* createInstImpl( IRBuilder* builder, UInt size, IROp op, @@ -206,26 +230,17 @@ namespace Slang UInt varArgCount = 0, IRValue* const* varArgs = nullptr) { - IRValue* inst = (IRInst*) malloc(size); + IRInst* inst = (IRInst*) malloc(size); memset(inst, 0, size); auto module = builder->getModule(); - if (!module || (type && type->op == kIROp_VoidType)) - { - // Can't or shouldn't assign an ID to this op - } - else - { - inst->id = ++module->idCounter; - } - inst->argCount = fixedArgCount + varArgCount + 1; + inst->argCount = fixedArgCount + varArgCount; inst->op = op; - auto operand = inst->getArgs(); + inst->type = type; - operand->init(inst, type); - operand++; + auto operand = inst->getArgs(); for( UInt aa = 0; aa < fixedArgCount; ++aa ) { @@ -394,7 +409,7 @@ namespace Slang bool operator==(IRInstKey const& left, IRInstKey const& right) { if(left.inst->op != right.inst->op) return false; - if(left.inst->parent != right.inst->parent) return false; + if(left.inst->parentBlock != right.inst->parentBlock) return false; if(left.inst->argCount != right.inst->argCount) return false; auto argCount = left.inst->argCount; @@ -412,7 +427,7 @@ namespace Slang int IRInstKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->parent)); + code = combineHash(code, Slang::GetHashCode(inst->parentBlock)); code = combineHash(code, Slang::GetHashCode(inst->argCount)); auto argCount = inst->argCount; @@ -424,233 +439,21 @@ namespace Slang return code; } - static IRParentInst* joinParentInstsForInsertion( - IRParentInst* left, - IRParentInst* right) - { - // Are they the same? Easy. - if(left == right) return left; - - // Have we already failed to find a location? Then bail. - if(!left) return nullptr; - if(!right) return nullptr; - - // Is one inst a parent of the other? Pick the child. - for( auto ll = left; ll; ll = ll->parent ) - { - // Did we find the right node in the parent list of the left? - if(ll == right) return left; - } - for( auto rr = right; rr; rr = rr->parent ) - { - // Did we find the left node in the parent list of the right? - if(rr == left) return right; - } - - // Seems like they are unrelated, so we should play it safe - return nullptr; - } - - - static IRInst* findOrEmitInstImpl( - IRBuilder* builder, - UInt size, - IROp op, - IRType* type, - UInt fixedArgCount, - IRValue* const* fixedArgs, - UInt varArgCount = 0, - IRValue* const* varArgs = nullptr) - { - // First, we need to pick a good insertion point - // for the instruction, which we do by looking - // at its operands. - // - - IRParentInst* parent = builder->shared->module; - if( type ) - { - parent = joinParentInstsForInsertion(parent, type->parent); - } - for( UInt aa = 0; aa < fixedArgCount; ++aa ) - { - auto arg = fixedArgs[aa]; - parent = joinParentInstsForInsertion(parent, arg->parent); - } - for( UInt aa = 0; aa < varArgCount; ++aa ) - { - auto arg = varArgs[aa]; - parent = joinParentInstsForInsertion(parent, arg->parent); - } - - // If we failed to find a good insertion point, then insert locally. - if( !parent ) - { - parent = builder->parentInst; - } - - if( parent->op == kIROp_Func ) - { - // We are trying to insert into a function, and we should really - // be inserting into its entry block. - assert(parent->firstChild); - parent = (IRBlock*) ((IRFunc*) parent)->firstChild; - } - - // We now know where we want to insert, but there might - // already be an equivalent instruction in that block. - // - // We will check for such an instruction in a slightly hacky - // way: we will construct a temporary instruction and - // then use it to look up in a cache of instructions. - - IRInst* keyInst = createInstImpl(builder, size, op, type, fixedArgCount, fixedArgs, varArgCount, varArgs); - keyInst->parent = parent; - - IRInstKey key; - key.inst = keyInst; - - IRInst* inst = nullptr; - if( builder->shared->globalValueNumberingMap.TryGetValue(key, inst) ) - { - // We found a match, so just use that. - - free(keyInst); - return inst; - } - - // No match, so use our "key" instruction for real. - inst = keyInst; - - builder->shared->globalValueNumberingMap.Add(key, inst); - - keyInst->parent = nullptr; - builder->addInst(parent, inst); - - return inst; - } - - template - static T* findOrEmitInst( - IRBuilder* builder, - IROp op, - IRType* type, - UInt argCount, - IRValue* const* args) - { - return (T*) findOrEmitInstImpl( - builder, - sizeof(T), - op, - type, - argCount, - args); - } - - template - static T* findOrEmitInst( - IRBuilder* builder, - IROp op, - IRType* type, - UInt fixedArgCount, - IRValue* const* fixedArgs, - UInt varArgCount, - IRValue* const* varArgs) - { - return (T*) findOrEmitInstImpl( - builder, - sizeof(T) + varArgCount * sizeof(IRUse), - op, - type, - fixedArgCount, - fixedArgs, - varArgCount, - varArgs); - } - - template - static T* findOrEmitInst( - IRBuilder* builder, - IROp op, - IRType* type) - { - return (T*) findOrEmitInstImpl( - builder, - sizeof(T), - op, - type, - 0, - nullptr); - } - - template - static T* findOrEmitInst( - IRBuilder* builder, - IROp op, - IRType* type, - IRInst* arg) - { - return (T*) findOrEmitInstImpl( - builder, - sizeof(T), - op, - type, - 1, - &arg); - } - - template - static T* findOrEmitInst( - IRBuilder* builder, - IROp op, - IRType* type, - IRInst* arg1, - IRInst* arg2) - { - IRInst* args[] = { arg1, arg2 }; - return (T*) findOrEmitInstImpl( - builder, - sizeof(T), - op, - type, - 2, - &args[0]); - } - - template - static T* findOrEmitInst( - IRBuilder* builder, - IROp op, - IRType* type, - IRInst* arg1, - IRInst* arg2, - IRInst* arg3) - { - IRInst* args[] = { arg1, arg2, arg3 }; - return (T*) findOrEmitInstImpl( - builder, - sizeof(T), - op, - type, - 3, - &args[0]); - } - // bool operator==(IRConstantKey const& left, IRConstantKey const& right) { - if(left.inst->op != right.inst->op) return false; - if(left.inst->type.usedValue != right.inst->type.usedValue) return false; - if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false; - if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false; + if(left.inst->op != right.inst->op) return false; + if(left.inst->type != right.inst->type) return false; + if(left.inst->u.ptrData[0] != right.inst->u.ptrData[0]) return false; + if(left.inst->u.ptrData[1] != right.inst->u.ptrData[1]) return false; return true; } int IRConstantKey::GetHashCode() { auto code = Slang::GetHashCode(inst->op); - code = combineHash(code, Slang::GetHashCode(inst->type.usedValue)); + code = combineHash(code, Slang::GetHashCode(inst->type)); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[0])); code = combineHash(code, Slang::GetHashCode(inst->u.ptrData[1])); return code; @@ -668,22 +471,20 @@ namespace Slang // at its operands. // - IRParentInst* parent = builder->shared->module; - IRConstant keyInst; memset(&keyInst, 0, sizeof(keyInst)); keyInst.op = op; - keyInst.type.usedValue = type; + keyInst.type = type; memcpy(&keyInst.u, value, valueSize); IRConstantKey key; key.inst = &keyInst; - IRConstant* inst = nullptr; - if( builder->shared->constantMap.TryGetValue(key, inst) ) + IRConstant* irValue = nullptr; + if( builder->shared->constantMap.TryGetValue(key, irValue) ) { // We found a match, so just use that. - return inst; + return irValue; } // We now know where we want to insert, but there might @@ -693,221 +494,25 @@ namespace Slang // way: we will construct a temporary instruction and // then use it to look up in a cache of instructions. - inst = createInst(builder, op, type); - memcpy(&inst->u, value, valueSize); + irValue = createInst(builder, op, type); + memcpy(&irValue->u, value, valueSize); - key.inst = inst; - builder->shared->constantMap.Add(key, inst); + key.inst = irValue; + builder->shared->constantMap.Add(key, irValue); - builder->addInst(parent, inst); - - return inst; + return irValue; } // - static IRType* getBaseTypeImpl(IRBuilder* builder, IROp op) - { - auto inst = findOrEmitInst( - builder, - op, - builder->getTypeType()); - return inst; - } - - IRType* IRBuilder::getBaseType(BaseType flavor) - { - switch( flavor ) - { - case BaseType::Void: return getVoidType(); - - case BaseType::Bool: return getBaseTypeImpl(this, kIROp_BoolType); - case BaseType::Float: return getBaseTypeImpl(this, kIROp_Float32Type); - case BaseType::Int: return getBaseTypeImpl(this, kIROp_Int32Type); - case BaseType::UInt: return getBaseTypeImpl(this, kIROp_UInt32Type); - - default: - SLANG_UNEXPECTED("unhandled base type"); - return nullptr; - } - } - - IRType* IRBuilder::getBoolType() - { - return getBaseType(BaseType::Bool); - } - - IRType* IRBuilder::getVectorType(IRType* elementType, IRValue* elementCount) - { - return findOrEmitInst( - this, - kIROp_VectorType, - getTypeType(), - elementType, - elementCount); - } - - IRType* IRBuilder::getMatrixType( - IRType* elementType, - IRValue* rowCount, - IRValue* columnCount) - { - return findOrEmitInst( - this, - kIROp_MatrixType, - getTypeType(), - elementType, - rowCount, - columnCount); - } - - IRType* IRBuilder::getArrayType(IRType* elementType, IRValue* elementCount) - { - // The client requests an unsized array by passing `nullptr` for - // the `elementCount`. - // - // We currently encode an unsized array as an ordinary array with - // zero elements. TODO: carefully consider this choice. - if (!elementCount) - { - elementCount = getIntValue( - getBaseType(BaseType::Int), - 0); - } - - return findOrEmitInst( - this, - kIROp_arrayType, - getTypeType(), - elementType, - elementCount); - } - - IRType* IRBuilder::getArrayType(IRType* elementType) - { - return getArrayType(elementType, nullptr); - } - - IRType* IRBuilder::getGenericParameterType(UInt index) - { - auto indexVal = getIntValue(getBaseType(BaseType::Int), index); - - return findOrEmitInst( - this, - kIROp_GenericParameterType, - getTypeType(), - indexVal); - - } - - - IRType* IRBuilder::getTypeType() - { - return findOrEmitInst( - this, - kIROp_TypeType, - nullptr); - } - - IRType* IRBuilder::getVoidType() - { - return findOrEmitInst( - this, - kIROp_VoidType, - getTypeType()); - } - - IRType* IRBuilder::getBlockType() - { - return findOrEmitInst( - this, - kIROp_BlockType, - getTypeType()); - } - - IRType* IRBuilder::getIntrinsicType( - IROp op, - UInt argCount, - IRValue* const* args) - { - return findOrEmitInst( - this, - op, - getTypeType(), - 0, - nullptr, - argCount, - args); - } - - - IRStructDecl* IRBuilder::createStructType() - { - return createInst( - this, - kIROp_StructType, - getTypeType()); - } - - IRStructField* IRBuilder::createStructField(IRType* fieldType) - { - return createInst( - this, - kIROp_StructField, - fieldType); - } - - - IRType* IRBuilder::getFuncType( - UInt paramCount, - IRType* const* paramTypes, - IRType* resultType) - { - // TODO: need to unique things here! - auto inst = createInstWithTrailingArgs( - this, - kIROp_FuncType, - getTypeType(), - 1, - (IRValue* const*) &resultType, - paramCount, - (IRValue* const*) paramTypes); - addInst(inst); - return inst; - } - - IRType* IRBuilder::getPtrType( - IRType* valueType, - IRAddressSpace addressSpace) - { - auto uintType = getBaseType(BaseType::UInt); - auto irAddressSpace = getIntValue(uintType, addressSpace); - - auto inst = findOrEmitInst( - this, - kIROp_PtrType, - getTypeType(), - valueType, - irAddressSpace); - return inst; - } - - - IRType* IRBuilder::getPtrType( - IRType* valueType) - { - return getPtrType(valueType, kIRAddressSpace_Default); - } - - IRValue* IRBuilder::getBoolValue(bool inValue) { IRIntegerValue value = inValue; return findOrEmitConstant( this, kIROp_boolConst, - getBoolType(), + getSession()->getBoolType(), sizeof(value), &value); } @@ -932,6 +537,18 @@ namespace Slang &value); } + IRValue* IRBuilder::getDeclRefVal( + DeclRefBase const& declRef) + { + // TODO: we should cache these... + auto irValue = createInst( + this, + kIROp_decl_ref, + nullptr); + irValue->declRef = declRef; + return irValue; + } + IRInst* IRBuilder::emitCallInst( IRType* type, IRValue* func, @@ -983,52 +600,69 @@ namespace Slang IRModule* IRBuilder::createModule() { - return createInst( - this, - kIROp_Module, - nullptr); + return new IRModule(); } IRFunc* IRBuilder::createFunc() { - return createInst( + return createValue( this, kIROp_Func, nullptr); } + IRGlobalVar* IRBuilder::createGlobalVar( + IRType* valueType) + { + auto ptrType = getSession()->getPtrType(valueType); + return createValue( + this, + kIROp_global_var, + ptrType); + } + IRBlock* IRBuilder::createBlock() { - return createInst( + return createValue( this, kIROp_Block, - getBlockType()); + getSession()->getIRBasicBlockType()); } IRBlock* IRBuilder::emitBlock() { - auto inst = createBlock(); - addInst(inst); - return inst; + auto bb = createBlock(); + + auto f = this->func; + if (f) + { + f->addBlock(bb); + this->block = bb; + } + return bb; } IRParam* IRBuilder::emitParam( IRType* type) { - auto inst = createInst( + auto param = createValue( this, kIROp_Param, type); - addInst(inst); - return inst; + + if (auto bb = block) + { + bb->addParam(param); + } + return param; } IRVar* IRBuilder::emitVar( IRType* type, IRAddressSpace addressSpace) { - auto allocatedType = getPtrType(type, addressSpace); + auto allocatedType = getSession()->getPtrType(type); auto inst = createInst( this, kIROp_Var, @@ -1047,14 +681,14 @@ namespace Slang IRInst* IRBuilder::emitLoad( IRValue* ptr) { - auto ptrType = ptr->getType(); - if( ptrType->op != kIROp_PtrType ) + auto ptrType = ptr->getType()->As(); + if( !ptrType ) { // Bad! return nullptr; } - auto valueType = ((IRPtrType*) ptrType)->getValueType(); + auto valueType = ptrType->getValueType(); auto inst = createInst( this, @@ -1070,11 +704,10 @@ namespace Slang IRValue* dstPtr, IRValue* srcVal) { - auto type = getVoidType(); auto inst = createInst( this, kIROp_Store, - type, + nullptr, dstPtr, srcVal); @@ -1085,7 +718,7 @@ namespace Slang IRInst* IRBuilder::emitFieldExtract( IRType* type, IRValue* base, - IRStructField* field) + IRValue* field) { auto inst = createInst( this, @@ -1101,7 +734,7 @@ namespace Slang IRInst* IRBuilder::emitFieldAddress( IRType* type, IRValue* base, - IRStructField* field) + IRValue* field) { auto inst = createInst( this, @@ -1170,7 +803,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getBaseType(BaseType::Int); + auto intType = getSession()->getBuiltinType(BaseType::Int); IRValue* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1212,7 +845,7 @@ namespace Slang UInt elementCount, UInt const* elementIndices) { - auto intType = getBaseType(BaseType::Int); + auto intType = getSession()->getBuiltinType(BaseType::Int); IRValue* irElementIndices[4]; for (UInt ii = 0; ii < elementCount; ++ii) @@ -1229,7 +862,7 @@ namespace Slang auto inst = createInst( this, kIROp_ReturnVal, - getVoidType(), + nullptr, val); addInst(inst); return inst; @@ -1240,7 +873,7 @@ namespace Slang auto inst = createInst( this, kIROp_ReturnVoid, - getVoidType()); + nullptr); addInst(inst); return inst; } @@ -1251,7 +884,7 @@ namespace Slang auto inst = createInst( this, kIROp_unconditionalBranch, - getVoidType(), + nullptr, block); addInst(inst); return inst; @@ -1263,7 +896,7 @@ namespace Slang auto inst = createInst( this, kIROp_break, - getVoidType(), + nullptr, target); addInst(inst); return inst; @@ -1275,7 +908,7 @@ namespace Slang auto inst = createInst( this, kIROp_continue, - getVoidType(), + nullptr, target); addInst(inst); return inst; @@ -1286,13 +919,13 @@ namespace Slang IRBlock* breakBlock, IRBlock* continueBlock) { - IRInst* args[] = { target, breakBlock, continueBlock }; + IRValue* args[] = { target, breakBlock, continueBlock }; UInt argCount = sizeof(args) / sizeof(args[0]); auto inst = createInst( this, kIROp_loop, - getVoidType(), + nullptr, argCount, args); addInst(inst); @@ -1304,13 +937,13 @@ namespace Slang IRBlock* trueBlock, IRBlock* falseBlock) { - IRInst* args[] = { val, trueBlock, falseBlock }; + IRValue* args[] = { val, trueBlock, falseBlock }; UInt argCount = sizeof(args) / sizeof(args[0]); auto inst = createInst( this, kIROp_conditionalBranch, - getVoidType(), + nullptr, argCount, args); addInst(inst); @@ -1322,13 +955,13 @@ namespace Slang IRBlock* trueBlock, IRBlock* afterBlock) { - IRInst* args[] = { val, trueBlock, afterBlock }; + IRValue* args[] = { val, trueBlock, afterBlock }; UInt argCount = sizeof(args) / sizeof(args[0]); auto inst = createInst( this, kIROp_if, - getVoidType(), + nullptr, argCount, args); addInst(inst); @@ -1341,13 +974,13 @@ namespace Slang IRBlock* falseBlock, IRBlock* afterBlock) { - IRInst* args[] = { val, trueBlock, falseBlock, afterBlock }; + IRValue* args[] = { val, trueBlock, falseBlock, afterBlock }; UInt argCount = sizeof(args) / sizeof(args[0]); auto inst = createInst( this, kIROp_ifElse, - getVoidType(), + nullptr, argCount, args); addInst(inst); @@ -1359,13 +992,13 @@ namespace Slang IRBlock* bodyBlock, IRBlock* breakBlock) { - IRInst* args[] = { val, bodyBlock, breakBlock }; + IRValue* args[] = { val, bodyBlock, breakBlock }; UInt argCount = sizeof(args) / sizeof(args[0]); auto inst = createInst( this, kIROp_loopTest, - getVoidType(), + nullptr, argCount, args); addInst(inst); @@ -1373,7 +1006,7 @@ namespace Slang } IRDecoration* IRBuilder::addDecorationImpl( - IRInst* inst, + IRValue* inst, UInt decorationSize, IRDecorationOp op) { @@ -1388,14 +1021,14 @@ namespace Slang return decoration; } - IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRInst* inst, Decl* decl) + IRHighLevelDeclDecoration* IRBuilder::addHighLevelDeclDecoration(IRValue* inst, Decl* decl) { auto decoration = addDecoration(inst, kIRDecorationOp_HighLevelDecl); decoration->decl = decl; return decoration; } - IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRInst* inst, Layout* layout) + IRLayoutDecoration* IRBuilder::addLayoutDecoration(IRValue* inst, Layout* layout) { auto decoration = addDecoration(inst); decoration->layout = layout; @@ -1409,6 +1042,9 @@ namespace Slang { StringBuilder* builder; int indent; + + UInt idCounter = 1; + Dictionary mapValueToID; }; static void dump( @@ -1454,27 +1090,59 @@ namespace Slang } } + bool opHasResult(IRValue* inst); + + static UInt getID( + IRDumpContext* context, + IRValue* value) + { + UInt id = 0; + if (context->mapValueToID.TryGetValue(value, id)) + return id; + + if (opHasResult(value)) + { + id = context->idCounter++; + } + + context->mapValueToID.Add(value, id); + return id; + } + static void dumpID( IRDumpContext* context, - IRInst* inst) + IRValue* inst) { if (!inst) { dump(context, ""); + return; } - else if( auto mangled = inst->findDecoration() ) - { - dump(context, "@"); - dump(context, mangled->mangledName.Buffer()); - } - else if(inst->id) - { - dump(context, "%"); - dump(context, (UInt) inst->id); - } - else + + switch(inst->op) { - dump(context, "_"); + case kIROp_Func: + { + auto irFunc = (IRFunc*) inst; + dump(context, "@"); + dump(context, irFunc->mangledName.Buffer()); + } + break; + + default: + { + UInt id = getID(context, inst); + if (id) + { + dump(context, "%"); + dump(context, id); + } + else + { + dump(context, "_"); + } + } + break; } } @@ -1484,8 +1152,9 @@ namespace Slang static void dumpOperand( IRDumpContext* context, - IRInst* inst) + IRValue* inst) { + // TODO: we should have a dedicated value for the `undef` case if (!inst) { dump(context, "undef"); @@ -1514,22 +1183,90 @@ namespace Slang break; } - auto type = inst->getType(); - if (type) + dumpID(context, inst); + } + + static void dump( + IRDumpContext* context, + Name* name) + { + dump(context, getText(name).Buffer()); + } + + + static void dumpDeclRef( + IRDumpContext* context, + DeclRef const& declRef); + + static void dumpVal( + IRDumpContext* context, + Val* val) + { + if(auto type = dynamic_cast(val)) { - switch (type->op) - { - case kIROp_TypeType: - dumpType(context, (IRType*)inst); - return; + dumpType(context, type); + } + else if(auto constIntVal = dynamic_cast(val)) + { + dump(context, constIntVal->value); + } + else if(auto genericParamVal = dynamic_cast(val)) + { + dumpDeclRef(context, genericParamVal->declRef); + } + else + { + dump(context, "???"); + } + } - default: - break; - } + static void dumpDeclRef( + IRDumpContext* context, + DeclRef const& declRef) + { + auto decl = declRef.getDecl(); + + auto parentDeclRef = declRef.GetParent(); + auto genericParentDeclRef = parentDeclRef.As(); + if(genericParentDeclRef) + { + parentDeclRef = genericParentDeclRef.GetParent(); } + if(parentDeclRef.As()) + { + parentDeclRef = DeclRef(); + } - dumpID(context, inst); + if(parentDeclRef) + { + dumpDeclRef(context, parentDeclRef); + dump(context, "."); + } + dump(context, decl->getName()); + + if(genericParentDeclRef) + { + auto subst = declRef.substitutions; + if( !subst || subst->genericDecl != genericParentDeclRef.getDecl() ) + { + // No actual substitutions in place here + dump(context, "<>"); + } + else + { + auto args = subst->args; + bool first = true; + dump(context, "<"); + for(auto aa : args) + { + if(!first) dump(context, ","); + dumpVal(context, aa); + first = false; + } + dump(context, ">"); + } + } } static void dumpType( @@ -1538,10 +1275,43 @@ namespace Slang { if (!type) { - dumpID(context, type); + dump(context, "_"); return; } + if(auto funcType = type->As()) + { + UInt paramCount = funcType->getParamCount(); + dump(context, "("); + for( UInt pp = 0; pp < paramCount; ++pp ) + { + if(pp != 0) dump(context, ", "); + dumpType(context, funcType->getParamType(pp)); + } + dump(context, ") -> "); + dumpType(context, funcType->getResultType()); + } + else if(auto arrayType = type->As()) + { + dumpType(context, arrayType->baseType); + dump(context, "["); + if(auto elementCount = arrayType->ArrayLength) + { + dumpVal(context, elementCount); + } + dump(context, "]"); + } + else if(auto declRefType = type->As()) + { + dumpDeclRef(context, declRefType->declRef); + } + else + { + // Need a default case here + dump(context, "???"); + } + +#if 0 auto op = type->op; auto opInfo = kIROpInfos[op]; @@ -1551,21 +1321,6 @@ namespace Slang dumpID(context, type); break; - case kIROp_FuncType: - { - auto funcType = (IRFuncType*) type; - UInt paramCount = funcType->getParamCount(); - dump(context, "("); - for( UInt pp = 0; pp < paramCount; ++pp ) - { - if(pp != 0) dump(context, ", "); - dumpType(context, funcType->getParamType(pp)); - } - dump(context, ") -> "); - dumpType(context, funcType->getResultType()); - } - break; - default: { dump(context, opInfo.name); @@ -1585,6 +1340,7 @@ namespace Slang } break; } +#endif } static void dumpInstTypeClause( @@ -1602,32 +1358,68 @@ namespace Slang static void dumpChildrenRaw( IRDumpContext* context, - IRParentInst* parent) + IRBlock* block) { - for (auto ii = parent->firstChild; ii; ii = ii->nextInst) + for (auto ii = block->firstInst; ii; ii = ii->nextInst) { dumpInst(context, ii); } } - static void dumpChildren( + static void dumpBlock( IRDumpContext* context, - IRInst* inst) + IRBlock* block) { - auto op = inst->op; - auto opInfo = &kIROpInfos[op]; - if (opInfo->flags & kIROpFlag_Parent) + context->indent--; + dump(context, "block "); + dumpID(context, block); + + if( block->getFirstParam() ) { - dumpIndent(context); - dump(context, "{\n"); - context->indent++; - auto parent = (IRParentInst*)inst; - dumpChildrenRaw(context, parent); - context->indent--; - dumpIndent(context); - dump(context, "}\n"); + dump(context, "(\n"); + context->indent += 2; + for (auto pp = block->getFirstParam(); pp; pp = pp->getNextParam()) + { + if (pp != block->getFirstParam()) + dump(context, ",\n"); + + dumpIndent(context); + dump(context, "param "); + dumpID(context, pp); + dumpInstTypeClause(context, pp->getType()); + } + context->indent -= 2; + dump(context, ")"); + } + dump(context, ":\n"); + context->indent++; + + dumpChildrenRaw(context, block); + } + + static void dumpChildrenRaw( + IRDumpContext* context, + IRFunc* func) + { + for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + dumpBlock(context, bb); } } + + static void dumpChildren( + IRDumpContext* context, + IRFunc* func) + { + dumpIndent(context); + dump(context, "{\n"); + context->indent++; + dumpChildrenRaw(context, func); + context->indent--; + dumpIndent(context); + dump(context, "}\n"); + } + static void dumpInst( IRDumpContext* context, IRInst* inst) @@ -1646,6 +1438,7 @@ namespace Slang // switch (op) { +#if 0 case kIROp_Module: dumpIndent(context); dump(context, "module\n"); @@ -1717,11 +1510,13 @@ namespace Slang dumpChildrenRaw(context, block); } return; +#endif default: break; } +#if 0 // We also want to special-case based on the *type* // of the instruction auto type = inst->getType(); @@ -1738,38 +1533,42 @@ namespace Slang return; } } +#endif // Okay, we have a seemingly "ordinary" op now dumpIndent(context); auto opInfo = &kIROpInfos[op]; + auto type = inst->getType(); - if (type && type->op == kIROp_TypeType) - { - dump(context, "type "); - dumpID(context, inst); - dump(context, "\t= "); - } - else if (type && type->op == kIROp_VoidType) + if (!type) { + // No result, okay... } else { - dump(context, "let "); - dumpID(context, inst); - dumpInstTypeClause(context, type); - dump(context, "\t= "); + auto basicType = type->As(); + if (basicType && basicType->baseType == BaseType::Void) + { + // No result, okay... + } + else + { + dump(context, "let "); + dumpID(context, inst); + dumpInstTypeClause(context, type); + dump(context, "\t= "); + } } - dump(context, opInfo->name); uint32_t argCount = inst->argCount; dump(context, "("); - for (uint32_t ii = 1; ii < argCount; ++ii) + for (uint32_t ii = 0; ii < argCount; ++ii) { - if (ii != 1) + if (ii != 0) dump(context, ", "); auto argVal = inst->getArgs()[ii].usedValue; @@ -1779,10 +1578,78 @@ namespace Slang dump(context, ")"); dump(context, "\n"); + } + + void dumpIRFunc( + IRDumpContext* context, + IRFunc* func) + { + dump(context, "\n"); + dumpIndent(context); + dump(context, "ir_func "); + dumpID(context, func); + dumpInstTypeClause(context, func->getType()); + dump(context, "\n"); + + dumpIndent(context); + dump(context, "{\n"); + context->indent++; - // The instruction might have children, - // so we need to handle those here - dumpChildren(context, inst); + for (auto bb = func->getFirstBlock(); bb; bb = bb->getNextBlock()) + { + if (bb != func->getFirstBlock()) + dump(context, "\n"); + dumpBlock(context, bb); + } + + context->indent--; + dump(context, "}\n"); + } + + void dumpIRGlobalVar( + IRDumpContext* context, + IRGlobalVar* var) + { + dump(context, "\n"); + dumpIndent(context); + dump(context, "ir_global_var "); + dumpID(context, var); + dumpInstTypeClause(context, var->getType()); + + // TODO: deal with the case where a global + // might have embedded initialization logic. + + dump(context, ";\n"); + } + + void dumpIRGlobalValue( + IRDumpContext* context, + IRGlobalValue* value) + { + switch (value->op) + { + case kIROp_Func: + dumpIRFunc(context, (IRFunc*)value); + break; + + case kIROp_global_var: + dumpIRGlobalVar(context, (IRGlobalVar*)value); + break; + + default: + dump(context, "???\n"); + break; + } + } + + void dumpIRModule( + IRDumpContext* context, + IRModule* module) + { + for (auto gv : module->globalValues) + { + dumpIRGlobalValue(context, gv); + } } void printSlangIRAssembly(StringBuilder& builder, IRModule* module) @@ -1791,7 +1658,7 @@ namespace Slang context.builder = &builder; context.indent = 0; - dumpChildrenRaw(&context, module); + dumpIRModule(&context, module); } String getSlangIRAssembly(IRModule* module) diff --git a/source/slang/ir.h b/source/slang/ir.h index 9c3124478..7e4dc8f83 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -11,36 +11,14 @@ namespace Slang { -// TODO(tfoley): We should ditch this enumeration -// and just use the IR opcodes that represent these -// types directly. The one major complication there -// is that the order of the enum values currently -// matters, since it determines promotion rank. -// We either need to keep that restriction, or -// look up promotion rank by some other means. -// -enum class BaseType -{ - // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this - - Void = 0, - Bool, - Int, - UInt, - UInt64, - Half, - Float, - Double, -}; - +class FuncType; +class Layout; +class Type; -class Layout; - -struct IRFunc; -struct IRInst; -struct IRModule; -struct IRParentInst; -struct IRType; +struct IRFunc; +struct IRInst; +struct IRModule; +struct IRValue; typedef unsigned int IROpFlags; enum : IROpFlags @@ -75,34 +53,6 @@ enum IROp : int16_t }; -#if 0 -enum IRPseudoOp -{ - kIRPseudoOp_Pos = -1000, - kIRPseudoOp_PreInc, - kIRPseudoOp_PreDec, - kIRPseudoOp_PostInc, - kIRPseudoOp_PostDec, - kIRPseudoOp_Sequence, - kIRPseudoOp_AddAssign, - kIRPseudoOp_SubAssign, - kIRPseudoOp_MulAssign, - kIRPseudoOp_DivAssign, - kIRPseudoOp_ModAssign, - kIRPseudoOp_AndAssign, - kIRPseudoOp_OrAssign, - kIRPseudoOp_XorAssign , - kIRPseudoOp_LshAssign, - kIRPseudoOp_RshAssign, - kIRPseudoOp_Assign, - kIRPseudoOp_BitNot, - kIRPseudoOp_And, - kIRPseudoOp_Or, - - kIROp_Invalid = -1, -}; -#endif - IROp findIROp(char const* name); // A logical operation/opcode in the IR @@ -125,12 +75,12 @@ IROpInfo getIROpInfo(IROp op); // A use of another value/inst within an IR operation struct IRUse { + // The value that is being used + IRValue* usedValue; + // The value that is doing the using. IRInst* user; - // The value that is being used - IRInst* usedValue; - // The next use of the same value IRUse* nextUse; @@ -138,7 +88,7 @@ struct IRUse // so that we can simplify updates. IRUse** prevLink; - void init(IRInst* user, IRInst* usedValue); + void init(IRInst* user, IRValue* usedValue); }; enum IRDecorationOp : uint16_t @@ -162,52 +112,61 @@ struct IRDecoration IRDecorationOp op; }; -typedef uint32_t IRInstID; +// Use AST-level types directly to represent the +// types of IR instructions/values +typedef Type IRType; + +struct IRBlock; -// In the IR, almost *everything* is an instruction, -// in order to make the representation as uniform as possible. -struct IRInst +// Base class for values in the IR +struct IRValue { // The operation that this value represents IROp op; - // A unique ID to represent the op when printing - // (or zero to indicate that the value of this - // op isn't special). - IRInstID id; + // The type of the result value of this instruction, + // or `null` to indicate that the instruction has + // no value. + RefPtr type; - // The total number of arguments of this instruction - // (including the type) - uint32_t argCount; - - // The parent of this instruction. - // This will often be a basic block, but we - // allow instructions to nest in more general ways. - IRParentInst* parent; - - // The next and previous instructions in the same parent block - IRInst* nextInst; - IRInst* prevInst; + Type* getType() { return type; } - // The first use of this value (start of a linked list) - IRUse* firstUse; - - // The linked list of decorations attached to this instruction + // The linked list of decorations attached to this value IRDecoration* firstDecoration; + // Look up a decoration in the list of decorations IRDecoration* findDecorationImpl(IRDecorationOp op); - template T* findDecoration() { return (T*) findDecorationImpl(IRDecorationOp(T::kDecorationOp)); } + // The first use of this value (start of a linked list) + IRUse* firstUse; - // The type of this value - IRUse type; - IRType* getType() { return (IRType*) type.usedValue; } +}; + +// Instructions are values that can be executed, +// and which take other values as operands +struct IRInst : IRValue +{ + // The total number of arguments of this instruction. + // + // TODO: We shouldn't need to allocate this on + // all instructions. Instead we should have + // instructions that need "vararg" support to + // allocate this field ahead of the `this` + // pointer. + uint32_t argCount; + + // The basic block that contains this instruction. + IRBlock* parentBlock; + + // The next and previous instructions in the same parent block + IRInst* nextInst; + IRInst* prevInst; UInt getArgCount() { @@ -216,20 +175,16 @@ struct IRInst IRUse* getArgs(); - IRInst* getArg(UInt index) + IRValue* getArg(UInt index) { return getArgs()[index].usedValue; } }; -// This type alias exists because I waffled on the name for a bit. -// All existing uses of `IRValue` should move to `IRInst` -typedef IRInst IRValue; - typedef int64_t IRIntegerValue; typedef double IRFloatingPointValue; -struct IRConstant : IRInst +struct IRConstant : IRValue { union { @@ -241,16 +196,6 @@ struct IRConstant : IRInst } u; }; -// Representation of a type at the IR level. -// Such a type may not correspond to the high-level-language notion -// of a type as used by the front end. -// -// Note that types are instructions in the IR, so that operations -// may take type operands as easily as values. -struct IRType : IRInst -{ -}; - // A instruction that ends a basic block (usually because of control flow) struct IRTerminatorInst : IRInst {}; @@ -258,23 +203,20 @@ struct IRTerminatorInst : IRInst bool isTerminatorInst(IROp op); bool isTerminatorInst(IRInst* inst); - -// A parent instruction contains a sequence of other instructions +// A function parameter is owned by a basic block, and represents +// either an incoming function parameter (in the entry block), or +// a value that flows from one SSA block to another (in a non-entry +// block). // -struct IRParentInst : IRInst +// In each case, the basic idea is that a block is a "label with +// arguments." +struct IRParam : IRValue { - // The first and last instruction in the container (or NULL in - // the case that the container is empty). - // - IRInst* firstChild; - IRInst* lastChild; -}; + IRParam* nextParam; + IRParam* prevParam; -// A function parameter is represented by an instruction -// in the entry block of a function. -struct IRParam : IRInst -{ - IRParam* getNextParam(); + IRParam* getNextParam() { return nextParam; } + IRParam* getPrevParam() { return prevParam; } }; // A basic block is a parent instruction that adds the constraint @@ -282,48 +224,97 @@ struct IRParam : IRInst // no function declarations, or nested blocks). We also expect // that the previous/next instruction are always a basic block. // -struct IRBlock : IRParentInst +struct IRBlock : IRValue { + // Linked list of the instructions contained in this block + // // Note that in a valid program, every block must end with // a "terminator" instruction, so these should be non-NULL, - // and `last` should actually be an `IRTerminatorInst`. + // and `lastInst` should actually be an `IRTerminatorInst`. + IRInst* firstInst; + IRInst* lastInst; - IRBlock* getPrevBlock() { return (IRBlock*) prevInst; } - IRBlock* getNextBlock() { return (IRBlock*) nextInst; } + IRInst* getFirstInst() { return firstInst; } + IRInst* getLastInst() { return lastInst; } - IRFunc* getParent() { return (IRFunc*)parent; } + // Links for the list of basic blocks in the parent function + IRBlock* prevBlock; + IRBlock* nextBlock; + + IRBlock* getPrevBlock() { return prevBlock; } + IRBlock* getNextBlock() { return nextBlock; } + + // Linked list of parameters of this block + IRParam* firstParam; + IRParam* lastParam; + + IRParam* getFirstParam() { return firstParam; } + IRParam* getLastParam() { return lastParam; } + void addParam(IRParam* param); + + // The parent function that contains this block + IRFunc* parentFunc; + + IRFunc* getParent() { return parentFunc; } - IRParam* getFirstParam(); }; -struct IRFuncType; +// For right now, we will represent the type of +// an IR function using the type of the AST +// function from which it was created. +// +// TODO: need to do this better. +typedef FuncType IRFuncType; + +struct IRGlobalValue : IRValue +{}; // A function is a parent to zero or more blocks of instructions. // // A function is itself a value, so that it can be a direct operand of // an instruction (e.g., a call). -struct IRFunc : IRParentInst +struct IRFunc : IRGlobalValue { - IRFuncType* getType() { return (IRFuncType*) type.usedValue; } + // The type of the IR-level function + IRFuncType* getType() { return (IRFuncType*) type.Ptr(); } + + // The mangled name, for a function + // that should have linkage. + String mangledName; - IRType* getResultType(); + // Convenience accessors for working with the + // function's type. + Type* getResultType(); UInt getParamCount(); - IRType* getParamType(UInt index); + Type* getParamType(UInt index); - IRBlock* getFirstBlock() { return (IRBlock*) firstChild; } - IRBlock* getLastBlock() { return (IRBlock*) lastChild; } + // The list of basic blocks in this function + IRBlock* firstBlock = nullptr; + IRBlock* lastBlock = nullptr; + IRBlock* getFirstBlock() { return firstBlock; } + IRBlock* getLastBlock() { return lastBlock; } + + // Add a block to the end of this function. + void addBlock(IRBlock* block); + + // Convenience accessor for the IR parameters, + // which are actually the parameters of the first + // block. IRParam* getFirstParam(); }; // A module is a parent to functions, global variables, types, etc. -struct IRModule : IRParentInst +struct IRModule : RefObject { // The designated entry-point function, if any IRFunc* entryPoint; - // A special counter used to assign logical ids to instructions in this module. - IRInstID idCounter; + // A list of all the functions and other + // global values declared in this module. + List globalValues; + + // TODO: need a symbol of all the global variables too }; void printSlangIRAssembly(StringBuilder& builder, IRModule* module); diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 06ad66bc4..7fce0c385 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -79,7 +79,7 @@ struct SubscriptInfo : ExtendedValueInfo struct BoundSubscriptInfo : ExtendedValueInfo { DeclRef declRef; - IRType* type; + RefPtr type; List args; UInt genericArgCount; }; @@ -218,8 +218,8 @@ struct BoundMemberInfo : ExtendedValueInfo // struct SwizzledLValueInfo : ExtendedValueInfo { - // IR-level The type of the expression. - IRType* type; + // The type of the expression. + RefPtr type; // The base expression (this should be an l-value) LoweredValInfo base; @@ -355,7 +355,7 @@ LoweredValInfo emitCompoundAssignOp( auto leftVal = builder->emitLoad(leftPtr); - IRInst* innerArgs[] = { leftVal, rightVal }; + IRValue* innerArgs[] = { leftVal, rightVal }; auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); builder->emitStore(leftPtr, innerOp); @@ -363,23 +363,31 @@ LoweredValInfo emitCompoundAssignOp( return LoweredValInfo::ptr(leftPtr); } -IRInst* getOneValOfType( +IRValue* getOneValOfType( IRGenContext* context, IRType* type) { - switch(type->op) + if (auto basicType = dynamic_cast(type)) { - case kIROp_Int32Type: - case kIROp_UInt32Type: - return context->irBuilder->getIntValue(type, 1); + switch (basicType->baseType) + { + case BaseType::Int: + case BaseType::UInt: + case BaseType::UInt64: + return context->irBuilder->getIntValue(type, 1); - case kIROp_Float32Type: - return context->irBuilder->getFloatValue(type, 1.0); + case BaseType::Float: + case BaseType::Double: + return context->irBuilder->getFloatValue(type, 1.0); - default: - SLANG_UNEXPECTED("inc/dec type"); - return nullptr; + default: + break; + } } + // TODO: should make sure to handle vector and matrix types here + + SLANG_UNEXPECTED("inc/dec type"); + return nullptr; } LoweredValInfo emitPreOp( @@ -396,9 +404,9 @@ LoweredValInfo emitPreOp( auto preVal = builder->emitLoad(argPtr); - IRInst* oneVal = getOneValOfType(context, type); + IRValue* oneVal = getOneValOfType(context, type); - IRInst* innerArgs[] = { preVal, oneVal }; + IRValue* innerArgs[] = { preVal, oneVal }; auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); builder->emitStore(argPtr, innerOp); @@ -420,9 +428,9 @@ LoweredValInfo emitPostOp( auto preVal = builder->emitLoad(argPtr); - IRInst* oneVal = getOneValOfType(context, type); + IRValue* oneVal = getOneValOfType(context, type); - IRInst* innerArgs[] = { preVal, oneVal }; + IRValue* innerArgs[] = { preVal, oneVal }; auto innerOp = builder->emitIntrinsicInst(type, op, 2, innerArgs); builder->emitStore(argPtr, innerOp); @@ -647,10 +655,7 @@ struct LoweredTypeInfo Simple, }; - union - { - IRType* type; - }; + RefPtr type; Flavor flavor; LoweredTypeInfo() @@ -743,6 +748,35 @@ LoweredValInfo lowerDecl( DeclBase* decl, Layout* layout); +IRType* getIntType( + IRGenContext* context) +{ + return context->getSession()->getBuiltinType(BaseType::Int); +} + +// Get a pointer type to the given element type +RefPtr getPtrType( + IRGenContext* context, + IRType* valueType) +{ + return context->getSession()->getPtrType(valueType); +} + +RefPtr getFuncType( + IRGenContext* context, + UInt paramCount, + RefPtr const* paramTypes, + IRType* resultType) +{ + RefPtr funcType = new FuncType(); + funcType->resultType = resultType; + for (UInt pp = 0; pp < paramCount; ++pp) + { + funcType->paramTypes.Add(paramTypes[pp]); + } + return funcType; +} + // struct ValLoweringVisitor : ValVisitor @@ -761,7 +795,7 @@ struct ValLoweringVisitor : ValVisitorgetBaseType(BaseType::Int); + auto type = getIntType(context); return LoweredValInfo::simple(getBuilder()->getIntValue(type, val->value)); } @@ -772,16 +806,7 @@ struct ValLoweringVisitor : ValVisitordeclRef); - auto loweredFuncVal = getSimpleVal(context, loweredFunc); - - // HACK: deal with the case where the decl might not - // lower to anything, and so we don't have a type to - // work with. - if (!loweredFuncVal) - return LoweredTypeInfo(); - - return loweredFuncVal->getType(); + return LoweredTypeInfo(type); } void addGenericArgs(List* ioArgs, DeclRefBase declRef) @@ -799,12 +824,16 @@ struct ValLoweringVisitor : ValVisitordeclRef.getDecl()->FindModifier() ) { auto builder = getBuilder(); - auto intType = builder->getBaseType(BaseType::Int); + auto intType = getIntType(context); // List irArgs; for( auto val : intrinsicTypeMod->irOperands ) @@ -831,61 +860,32 @@ struct ValLoweringVisitor : ValVisitorgetBaseType(type->baseType); + return LoweredTypeInfo(type); } LoweredTypeInfo visitVectorExpressionType(VectorExpressionType* type) { - auto irElementType = lowerSimpleType(context, type->elementType); - auto irElementCount = lowerSimpleVal(context, type->elementCount); - - return getBuilder()->getVectorType(irElementType, irElementCount); + return LoweredTypeInfo(type); } LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type) { - auto irElementType = lowerSimpleType(context, type->getElementType()); - auto irRowCount = lowerSimpleVal(context, type->getRowCount()); - auto irColumnCount = lowerSimpleVal(context, type->getColumnCount()); - - return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount); + return LoweredTypeInfo(type); } - LoweredTypeInfo getArrayType( - LoweredTypeInfo const& loweredElementType, - IRValue* irElementCount) + LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type) { - switch (loweredElementType.flavor) - { - case LoweredTypeInfo::Flavor::Simple: - return getBuilder()->getArrayType( - loweredElementType.type, - irElementCount); - break; - - default: - SLANG_UNEXPECTED("array element type"); - break; - } + return LoweredTypeInfo(type); } - LoweredTypeInfo visitArrayExpressionType(ArrayExpressionType* type) + LoweredTypeInfo visitIRBasicBlockType(IRBasicBlockType* type) { - auto loweredElementType = lowerType(context, type->baseType); - if (auto elementCount = type->ArrayLength) - { - auto irElementCount = lowerSimpleVal(context, elementCount); - return getArrayType(loweredElementType, irElementCount); - } - else - { - return getArrayType(loweredElementType, nullptr); - } + return LoweredTypeInfo(type); } }; @@ -1017,9 +1017,7 @@ struct ExprLoweringVisitorBase : ExprVisitor if (auto fieldDeclRef = declRef.As()) { // Okay, easy enough: we have a reference to a field of a struct type... - - auto loweredField = ensureDecl(context, fieldDeclRef); - return extractField(loweredType, loweredBase, loweredField); + return extractField(loweredType, loweredBase, fieldDeclRef); } else if (auto callableDeclRef = declRef.As()) { @@ -1045,14 +1043,12 @@ struct ExprLoweringVisitorBase : ExprVisitor // in order for a dereference to make senese, so we just // need to extract the value type from that pointer here. // - auto loweredBaseVal = getSimpleVal(context, loweredBase); - auto loweredBaseType = loweredBaseVal->getType(); - switch( loweredBaseType->op ) - { - case kIROp_PtrType: - // TODO: should we enumerate these explicitly? - case kIROp_ConstantBufferType: - case kIROp_TextureBufferType: + IRValue* loweredBaseVal = getSimpleVal(context, loweredBase); + RefPtr loweredBaseType = loweredBaseVal->getType(); + + if (loweredBaseType->As() + || loweredBaseType->As()) + { // Note that we do *not* perform an actual `load` operation // here, but rather just use the pointer value to construct // an appropriate `LoweredValInfo` representing the underlying @@ -1064,8 +1060,9 @@ struct ExprLoweringVisitorBase : ExprVisitor // and is just a bit of pointer math. // return LoweredValInfo::ptr(loweredBaseVal); - - default: + } + else + { SLANG_UNIMPLEMENTED_X("codegen for deref expression"); return LoweredValInfo(); } @@ -1472,7 +1469,7 @@ struct ExprLoweringVisitorBase : ExprVisitor case LoweredValInfo::Flavor::Ptr: return LoweredValInfo::ptr( builder->emitElementAddress( - builder->getPtrType(getSimpleType(type)), + getPtrType(context, getSimpleType(type)), baseVal.val, indexVal)); @@ -1484,9 +1481,9 @@ struct ExprLoweringVisitorBase : ExprVisitor } LoweredValInfo extractField( - LoweredTypeInfo fieldType, - LoweredValInfo base, - LoweredValInfo field) + LoweredTypeInfo fieldType, + LoweredValInfo base, + DeclRef field) { switch (base.flavor) { @@ -1497,7 +1494,7 @@ struct ExprLoweringVisitorBase : ExprVisitor getBuilder()->emitFieldExtract( getSimpleType(fieldType), irBase, - (IRStructField*) getSimpleVal(context, field))); + getBuilder()->getDeclRefVal(field))); } break; @@ -1509,9 +1506,9 @@ struct ExprLoweringVisitorBase : ExprVisitor IRValue* irBasePtr = base.val; return LoweredValInfo::ptr( getBuilder()->emitFieldAddress( - getBuilder()->getPtrType(getSimpleType(fieldType)), + getPtrType(context, getSimpleType(fieldType)), irBasePtr, - (IRStructField*) getSimpleVal(context, field))); + getBuilder()->getDeclRefVal(field))); } break; } @@ -1598,7 +1595,7 @@ struct RValueExprLoweringVisitor : ExprLoweringVisitorBasegetBaseType(BaseType::Int); + auto irIntType = getIntType(context); UInt elementCount = (UInt)expr->elementCount; IRValue* irElementIndices[4]; @@ -1661,34 +1658,22 @@ struct StmtLoweringVisitor : StmtVisitor void insertBlock(IRBlock* block) { auto builder = getBuilder(); - auto parent = builder->parentInst; - IRBlock* prevBlock = nullptr; - IRFunc* parentFunc = nullptr; - - switch (parent->op) - { - case kIROp_Block: - prevBlock = (IRBlock*)parent; - parentFunc = prevBlock->getParent(); - break; - - default: - SLANG_UNEXPECTED("bad parent kind for block"); - return; - } + auto prevBlock = builder->block; + auto parentFunc = prevBlock->parentFunc; // If the previous block doesn't already have // a terminator instruction, then be sure to // emit a branch to the new block. - if (!isTerminatorInst(prevBlock->lastChild)) + if (!isTerminatorInst(prevBlock->lastInst)) { builder->emitBranch(block); } - builder->parentInst = parentFunc; - builder->addInst(block); - builder->parentInst = block; + parentFunc->addBlock(block); + + builder->func = parentFunc; + builder->block = block; } // Start a new block at the current location. @@ -2025,8 +2010,8 @@ top: auto loweredBase = swizzleInfo->base; // Load from the base value: - IRInst* irLeftVal = getSimpleVal(context, loweredBase); - auto irRightVal = getSimpleVal(context, right); + IRValue* irLeftVal = getSimpleVal(context, loweredBase); + IRValue* irRightVal = getSimpleVal(context, right); // Now apply the swizzle IRInst* irSwizzled = builder->emitSwizzleSet( @@ -2066,7 +2051,7 @@ top: emitCallToDeclRef( context, - builder->getVoidType(), + context->getSession()->getVoidType(), setterDeclRef, allArgs, subscriptInfo->genericArgCount); @@ -2153,8 +2138,67 @@ struct DeclLoweringVisitor : DeclVisitor return LoweredValInfo(); } + bool isGlobalVarDecl(VarDeclBase* decl) + { + auto parent = decl->ParentDecl; + if (dynamic_cast(parent)) + { + // Variable declared at global scope? -> Global. + return true; + } + + return false; + } + + LoweredValInfo lowerGlobalVarDecl(VarDeclBase* decl) + { + auto varType = lowerSimpleType(context, decl->getType()); + + IRAddressSpace addressSpace = kIRAddressSpace_Default; + if (decl->HasModifier()) + { + addressSpace = kIRAddressSpace_GroupShared; + } + + auto builder = getBuilder(); + auto irGlobal = builder->createGlobalVar(varType); + + if (decl) + { + builder->addHighLevelDeclDecoration(irGlobal, decl); + } + + if (auto layout = getLayout()) + { + builder->addLayoutDecoration(irGlobal, layout); + } + + // A global variable's SSA value is a *pointer* to + // the underlying storage. + auto globalVal = LoweredValInfo::ptr(irGlobal); + context->shared->declValues.Add( + DeclRef(decl, nullptr), + globalVal); + + if( auto initExpr = decl->initExpr ) + { + // TODO: need to handle global with initializer! + } + + getBuilder()->getModule()->globalValues.Add(irGlobal); + + return globalVal; + } + LoweredValInfo visitVarDeclBase(VarDeclBase* decl) { + // Detect global (or effectively global) variables + // and handle them differently. + if (isGlobalVarDecl(decl)) + { + return lowerGlobalVarDecl(decl); + } + // A user-defined variable declaration will usually turn into // an `alloca` operation for the variable's storage, // plus some code to initialize it and then store to the variable. @@ -2219,6 +2263,7 @@ struct DeclLoweringVisitor : DeclVisitor LoweredValInfo visitAggTypeDecl(AggTypeDecl* decl) { +#if 0 // User-defined aggregate type: need to translate into // a corresponding IR aggregate type. @@ -2257,6 +2302,10 @@ struct DeclLoweringVisitor : DeclVisitor builder->addInst(irStruct); return LoweredValInfo::simple(irStruct); +#else + // TODO: What is there to do with a `struct` type? + return LoweredValInfo(); +#endif } // Sometimes we need to refer to a declaration the way that it would be specialized @@ -2592,7 +2641,7 @@ struct DeclLoweringVisitor : DeclVisitor } void trySetMangledName( - IRInst* inst, + IRFunc* irFunc, Decl* decl) { // We want to generate a mangled name for the given declaration and attach @@ -2605,8 +2654,7 @@ struct DeclLoweringVisitor : DeclVisitor String mangledName = getMangledName(decl); - auto decoration = getBuilder()->addDecoration(inst); - decoration->mangledName = mangledName; + irFunc->mangledName = mangledName; } @@ -2649,11 +2697,11 @@ struct DeclLoweringVisitor : DeclVisitor // need to create an IR function here IRFunc* irFunc = subBuilder->createFunc(); - subBuilder->parentInst = irFunc; + subBuilder->func = irFunc; trySetMangledName(irFunc, decl); - List paramTypes; + List> paramTypes; // We first need to walk the generic parameters (if any) // because these will influence the declared type of @@ -2662,6 +2710,7 @@ struct DeclLoweringVisitor : DeclVisitor for( auto genericParamDecl : parameterLists.genericParams ) { UInt genericParamIndex = genericParamCounter++; +#if 0 if( auto genericTypeParamDecl = dynamic_cast(genericParamDecl) ) { // In the logical type for the function, a generic @@ -2675,10 +2724,11 @@ struct DeclLoweringVisitor : DeclVisitor // to the appropriate generic parameter position. IRType* irParameterType = context->irBuilder->getGenericParameterType(genericParamIndex); - LoweredValInfo LoweredValInfo = LoweredValInfo::simple(irParameterType); + LoweredValInfo LoweredValInfo = LoweredValInfo::type(irParameterType); subContext->shared->declValues[makeDeclRef(genericTypeParamDecl)] = LoweredValInfo; } else +#endif { // TODO: handle the other cases here. SLANG_UNEXPECTED("generic parameter kind"); @@ -2702,7 +2752,7 @@ struct DeclLoweringVisitor : DeclVisitor // // TODO: Is this the best representation we can use? - auto irPtrType = subBuilder->getPtrType(irParamType); + auto irPtrType = getPtrType(context, irParamType); paramTypes.Add(irPtrType); } } @@ -2726,14 +2776,15 @@ struct DeclLoweringVisitor : DeclVisitor // Instead, a setter always returns `void` // - irResultType = getBuilder()->getVoidType(); + irResultType = context->getSession()->getVoidType(); } - auto irFuncType = getBuilder()->getFuncType( + auto irFuncType = getFuncType( + context, paramTypes.Count(), paramTypes.Buffer(), irResultType); - irFunc->type.init(irFunc, irFuncType); + irFunc->type = irFuncType; if (!decl->Body) { @@ -2753,7 +2804,7 @@ struct DeclLoweringVisitor : DeclVisitor // This is a function definition, so we need to actually // construct IR for the body... IRBlock* entryBlock = subBuilder->emitBlock(); - subBuilder->parentInst = entryBlock; + subBuilder->block = entryBlock; UInt paramTypeIndex = 0; for( auto paramInfo : parameterLists.params ) @@ -2771,7 +2822,7 @@ struct DeclLoweringVisitor : DeclVisitor // // TODO: Is this the best representation we can use? - auto irPtrType = (IRPtrType*)irParamType; + auto irPtrType = irParamType.As(); IRParam* irParamPtr = subBuilder->emitParam(irPtrType); if(auto paramDecl = paramInfo.decl) @@ -2829,14 +2880,21 @@ struct DeclLoweringVisitor : DeclVisitor // We need to carefully add a terminator instruction to the end // of the body, in case the user didn't do so. - if (!isTerminatorInst(subContext->irBuilder->parentInst->lastChild)) + if (!isTerminatorInst(subContext->irBuilder->block->lastInst)) { - if (irResultType->op == kIROp_VoidType) + if (irResultType->Equals(context->getSession()->getVoidType())) { + // `void`-returning function can get an implicit + // return on exit of the body statement. subContext->irBuilder->emitReturn(); } else { + // Value-returning function is expected to `return` + // on every control-flow path. We need to enforce + // this by putting an `unreachable` terminator here, + // and then emit a dataflow error if this block + // can't be eliminated. SLANG_UNEXPECTED("Needed a return here"); subContext->irBuilder->emitReturn(); } @@ -2845,7 +2903,7 @@ struct DeclLoweringVisitor : DeclVisitor getBuilder()->addHighLevelDeclDecoration(irFunc, decl); - getBuilder()->addInst(irFunc); + getBuilder()->getModule()->globalValues.Add(irFunc); return LoweredValInfo::simple(irFunc); } @@ -2878,7 +2936,6 @@ LoweredValInfo ensureDecl( IRBuilder subIRBuilder; subIRBuilder.shared = context->irBuilder->shared; - subIRBuilder.parentInst = subIRBuilder.shared->module; IRGenContext subContext = *context; @@ -2997,15 +3054,14 @@ IRModule* lowerEntryPointToIR( SharedIRBuilder sharedBuilderStorage; SharedIRBuilder* sharedBuilder = &sharedBuilderStorage; sharedBuilder->module = nullptr; + sharedBuilder->session = entryPoint->compileRequest->mSession; IRBuilder builderStorage; IRBuilder* builder = &builderStorage; builder->shared = sharedBuilder; - builder->parentInst = nullptr; IRModule* module = builder->createModule(); sharedBuilder->module = module; - builder->parentInst = module; context->irBuilder = builder; diff --git a/source/slang/lower.cpp b/source/slang/lower.cpp index 85f19f14c..6d848062c 100644 --- a/source/slang/lower.cpp +++ b/source/slang/lower.cpp @@ -708,6 +708,12 @@ struct LoweringVisitor return result; } + RefPtr visitIRBasicBlockType(IRBasicBlockType* type) + { + return type; + } + + RefPtr visitErrorType(ErrorType* type) { return type; @@ -732,9 +738,16 @@ struct LoweringVisitor RefPtr visitFuncType(FuncType* type) { - RefPtr loweredType = getFuncType( - getSession(), - translateDeclRef(DeclRef(type->declRef)).As()); + RefPtr loweredType = new FuncType(); + loweredType->resultType = lowerType(type->resultType); + for (auto paramType : type->paramTypes) + { + auto loweredParamType = lowerType(paramType); + + // TODO: it seems like this step needs to scalarize + // in the case where a parameter type is a tuple... + loweredType->paramTypes.Add(loweredParamType); + } return loweredType; } diff --git a/source/slang/syntax.cpp b/source/slang/syntax.cpp index 475802cb1..4dbb6c05b 100644 --- a/source/slang/syntax.cpp +++ b/source/slang/syntax.cpp @@ -216,6 +216,9 @@ void Type::accept(IValVisitor* visitor, void* extra) overloadedType = new OverloadGroupType(); overloadedType->setSession(this); + + irBasicBlockType = new IRBasicBlockType(); + irBasicBlockType->setSession(this); } Type* Session::getBoolType() @@ -268,6 +271,29 @@ void Type::accept(IValVisitor* visitor, void* extra) return errorType; } + Type* Session::getIRBasicBlockType() + { + return irBasicBlockType; + } + + RefPtr Session::getPtrType( + RefPtr valueType) + { + auto genericDecl = findMagicDecl( + this, "PtrType").As(); + auto typeDecl = genericDecl->inner; + + auto substitutions = new Substitutions(); + substitutions->genericDecl = genericDecl.Ptr(); + substitutions->args.Add(valueType); + + auto declRef = DeclRef(typeDecl.Ptr(), substitutions); + + return DeclRefType::Create( + this, + declRef)->As(); + } + SyntaxClass Session::findSyntaxClass(Name* name) { SyntaxClass syntaxClass; @@ -545,7 +571,26 @@ void Type::accept(IValVisitor* visitor, void* extra) else { - SLANG_UNEXPECTED("unhandled type"); + auto classInfo = session->findSyntaxClass( + session->getNamePool()->getName(magicMod->name)); + if (!classInfo.classInfo) + { + SLANG_UNEXPECTED("unhandled type"); + } + + auto type = classInfo.createInstance(); + if (!type) + { + SLANG_UNEXPECTED("constructor failure"); + } + + auto declRefType = dynamic_cast(type); + if (!declRefType) + { + SLANG_UNEXPECTED("expected a declaration reference type"); + } + declRefType->declRef = declRef; + return declRefType; } } else @@ -578,6 +623,28 @@ void Type::accept(IValVisitor* visitor, void* extra) return (int)(int64_t)(void*)this; } + // IRBasicBlockType + + String IRBasicBlockType::ToString() + { + return "Block"; + } + + bool IRBasicBlockType::EqualsImpl(Type * /*type*/) + { + return false; + } + + Type* IRBasicBlockType::CreateCanonicalType() + { + return this; + } + + int IRBasicBlockType::GetHashCode() + { + return (int)(int64_t)(void*)this; + } + // InitializerListType String InitializerListType::ToString() @@ -653,18 +720,43 @@ void Type::accept(IValVisitor* visitor, void* extra) String FuncType::ToString() { - // TODO: a better approach than this - if (declRef) - return getText(declRef.GetName()); - else - return "/* unknown FuncType */"; + StringBuilder sb; + sb << "("; + UInt paramCount = getParamCount(); + for (UInt pp = 0; pp < paramCount; ++pp) + { + if (pp != 0) sb << ", "; + sb << getParamType(pp)->ToString(); + } + sb << ") -> "; + sb << getResultType()->ToString(); + return sb.ProduceString(); } bool FuncType::EqualsImpl(Type * type) { if (auto funcType = type->As()) { - return declRef == funcType->declRef; + auto paramCount = getParamCount(); + auto otherParamCount = funcType->getParamCount(); + if (paramCount != otherParamCount) + return false; + + for (UInt pp = 0; pp < paramCount; ++pp) + { + auto paramType = getParamType(pp); + auto otherParamType = funcType->getParamType(pp); + if (!paramType->Equals(otherParamType)) + return false; + } + + if(!resultType->Equals(funcType->resultType)) + return false; + + // TODO: if we ever introduce other kinds + // of qualification on function types, we'd + // want to consider it here. + return true; } return false; } @@ -676,7 +768,16 @@ void Type::accept(IValVisitor* visitor, void* extra) int FuncType::GetHashCode() { - return declRef.GetHashCode(); + int hashCode = getResultType()->GetHashCode(); + UInt paramCount = getParamCount(); + hashCode = combineHash(hashCode, Slang::GetHashCode(paramCount)); + for (UInt pp = 0; pp < paramCount; ++pp) + { + hashCode = combineHash( + hashCode, + getParamType(pp)->GetHashCode()); + } + return hashCode; } // TypeType @@ -782,6 +883,13 @@ void Type::accept(IValVisitor* visitor, void* extra) return this->declRef.substitutions->args[2].As().Ptr(); } + // PtrTypeBase + + Type* PtrTypeBase::getValueType() + { + return this->declRef.substitutions->args[0].As().Ptr(); + } + // GenericParamIntVal bool GenericParamIntVal::EqualsVal(Val* val) @@ -1161,7 +1269,13 @@ void Type::accept(IValVisitor* visitor, void* extra) { auto funcType = new FuncType(); funcType->setSession(session); - funcType->declRef = declRef; + + funcType->resultType = GetResultType(declRef); + for (auto pp : GetParameters(declRef)) + { + funcType->paramTypes.Add(GetType(pp)); + } + return funcType; } diff --git a/source/slang/syntax.h b/source/slang/syntax.h index 8d6a1edbc..f5e22d9b5 100644 --- a/source/slang/syntax.h +++ b/source/slang/syntax.h @@ -74,6 +74,29 @@ namespace Slang kConversionCost_ScalarToVector = 1, }; + // TODO(tfoley): We should ditch this enumeration + // and just use the IR opcodes that represent these + // types directly. The one major complication there + // is that the order of the enum values currently + // matters, since it determines promotion rank. + // We either need to keep that restriction, or + // look up promotion rank by some other means. + // + enum class BaseType + { + // Note(tfoley): These are ordered in terms of promotion rank, so be vareful when messing with this + + Void = 0, + Bool, + Int, + UInt, + UInt64, + Half, + Float, + Double, + }; + + // Forward-declare all syntax classes #define SYNTAX_CLASS(NAME, BASE, ...) class NAME; diff --git a/source/slang/type-defs.h b/source/slang/type-defs.h index 7883b5c42..7fbefc368 100644 --- a/source/slang/type-defs.h +++ b/source/slang/type-defs.h @@ -41,6 +41,20 @@ protected: ) END_SYNTAX_CLASS() +// The type of a reference to a basic block +// in our IR +SYNTAX_CLASS(IRBasicBlockType, Type) +RAW( +public: + virtual String ToString() override; + +protected: + virtual bool EqualsImpl(Type * type) override; + virtual Type* CreateCanonicalType() override; + virtual int GetHashCode() override; +) +END_SYNTAX_CLASS() + // A type that takes the form of a reference to some declaration SYNTAX_CLASS(DeclRefType, Type) DECL_FIELD(DeclRef, declRef) @@ -221,6 +235,8 @@ END_SYNTAX_CLASS() // Other cases of generic types known to the compiler SYNTAX_CLASS(BuiltinGenericType, DeclRefType) SYNTAX_FIELD(RefPtr, elementType) + + RAW(Type* getElementType() { return elementType; }) END_SYNTAX_CLASS() // Types that behave like pointers, in that they can be @@ -353,6 +369,33 @@ protected: ) END_SYNTAX_CLASS() +// Base class for types that map down to +// simple pointers as part of code generation. +SYNTAX_CLASS(PtrTypeBase, DeclRefType) +RAW( + // Get the type of the pointed-to value. + Type* getValueType(); +) +END_SYNTAX_CLASS() + +// A true (user-visible) pointer type, e.g., `T*` +SYNTAX_CLASS(PtrType, PtrTypeBase) +END_SYNTAX_CLASS() + +// A type that represents the behind-the-scenes +// logical pointer that is passed for an `out` +// or `in out` parameter +SYNTAX_CLASS(OutTypeBase, PtrTypeBase) +END_SYNTAX_CLASS() + +// The type for an `out` parameter, e.g., `out T` +SYNTAX_CLASS(OutType, OutTypeBase) +END_SYNTAX_CLASS() + +// The type for an `in out` parameter, e.g., `in out T` +SYNTAX_CLASS(InOutType, OutTypeBase) +END_SYNTAX_CLASS() + // A type alias of some kind (e.g., via `typedef`) SYNTAX_CLASS(NamedExpressionType, Type) DECL_FIELD(DeclRef, declRef) @@ -375,17 +418,27 @@ protected: ) END_SYNTAX_CLASS() -// Function types are currently used for references to symbols that name -// either ordinary functions, or "component functions." -// We do not directly store a representation of the type, and instead -// use a reference to the symbol to stand in for its logical type +// A function type is defined by its parameter types +// and its result type. SYNTAX_CLASS(FuncType, Type) - DECL_FIELD(DeclRef, declRef) + + // TODO: We may want to preserve parameter names + // in the list here, just so that we can print + // out friendly names when printing a function + // type, even if they don't affect the actual + // semantic type underneath. + + FIELD(List>, paramTypes) + FIELD(RefPtr, resultType) RAW( FuncType() {} + UInt getParamCount() { return paramTypes.Count(); } + Type* getParamType(UInt index) { return paramTypes[index]; } + Type* getResultType() { return resultType; } + virtual String ToString() override; protected: virtual bool EqualsImpl(Type * type) override; diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp index 65ecefc01..d3e44a947 100644 --- a/source/slang/vm.cpp +++ b/source/slang/vm.cpp @@ -11,28 +11,43 @@ namespace Slang { -// Representation of a type during VM execution -struct VMTypeImpl +struct VMValImpl { + // opcode used to construct the value uint32_t op; +}; + +// Representation of a type during VM execution +struct VMTypeImpl : VMValImpl +{ + // number of arguments to the type + uint32_t argCount; // Size and alignment of instances of this // type. UInt size; UInt alignment; + + // operands follow }; -struct VMType + +struct VMVal { - VMTypeImpl* impl; + VMValImpl* impl; - UInt getSize() { return impl->size; } - UInt getAlignment() { return impl->alignment; } + VMValImpl* getImpl() { return impl; } +}; + +struct VMType : VMVal +{ + VMTypeImpl* getImpl() { return (VMTypeImpl*) impl; } + UInt getSize() { return getImpl()->size; } + UInt getAlignment() { return getImpl()->alignment; } }; struct VMPtrTypeImpl : VMTypeImpl { VMType base; - uint32_t addressSpace; }; struct VMReg @@ -47,20 +62,20 @@ struct VMReg struct VMConst { // Type of the constant - VMType type; + VMType type; // Operand address to use void* ptr; }; -struct VMFrame; +struct VMModule; // Information about a function after it has been // loaded into the VM. struct VMFunc { // The parent module that this function belongs to - VMFrame* module; + VMModule* module; BCFunc* bcFunc; VMReg* regs; @@ -84,8 +99,14 @@ struct VMFrame // Registers are stored after this point. }; -struct VMModule : VMFunc +struct VM; + +struct VMModule { + BCModule* bcModule; + VM* vm; + void** symbols; + VMType* types; }; UInt decodeUInt(BCOp** ioPtr) @@ -93,13 +114,28 @@ UInt decodeUInt(BCOp** ioPtr) BCOp* ptr = *ioPtr; UInt value = *ptr++; - if( value >= 128 ) + if( value < 128 ) { - SLANG_UNEXPECTED("deal with this later"); + *ioPtr = ptr; + return value; } - *ioPtr = ptr; - return value; + // Slower path for variable-length encoding + + UInt result = 0; + for(;;) + { + value = value & 0x7F; + result = (result << 7) | value; + + if(value < 127) + { + *ioPtr = ptr; + return value; + } + + value = *ptr++; + } } Int decodeSInt(BCOp** ioPtr) @@ -194,9 +230,15 @@ T& decodeOperand(VMFrame* frame, BCOp** ioIP) return *decodeOperandPtr(frame, ioIP); } +VMType decodeType(VMFrame* frame, BCOp** ioIP) +{ + UInt id = decodeUInt(ioIP); + return frame->func->module->types[id]; +} + VMFunc* loadVMFunc( BCFunc* bcFunc, - VMFrame* vmModuleInstance); + VMModule* vmModule); struct VMSizeAlign { @@ -231,9 +273,30 @@ VMSizeAlign getVMSymbolSize(BCSymbol* symbol) return result; } +VMType getType( + VMModule* vmModule, + uint32_t typeID) +{ + return vmModule->types[typeID]; +} + +void* getGlobalPtr( + VMModule* vmModule, + uint32_t globalID) +{ + return vmModule->symbols[globalID]; +} + +VMType getGlobalType( + VMModule* vmModule, + uint32_t globalID) +{ + return getType(vmModule, vmModule->bcModule->symbols[globalID]->typeID); +} + VMFunc* loadVMFunc( BCFunc* bcFunc, - VMFrame* vmModuleInstance) + VMModule* vmModule) { UInt regCount = bcFunc->regCount; UInt constCount = bcFunc->constCount; @@ -245,7 +308,7 @@ VMFunc* loadVMFunc( VMReg* vmRegs = (VMReg*) (vmFunc + 1); VMConst* vmConsts = (VMConst*) (vmRegs + regCount); - vmFunc->module = vmModuleInstance; + vmFunc->module = vmModule; vmFunc->bcFunc = bcFunc; vmFunc->regs = vmRegs; vmFunc->consts = vmConsts; @@ -254,31 +317,14 @@ VMFunc* loadVMFunc( for( UInt rr = 0; rr < regCount; ++rr ) { BCReg* bcReg = &bcFunc->regs[rr]; - auto typeGlobalID = bcReg->typeGlobalID; - - // HACK: when we are loading a module itself, we might - // not yet know the size for the things it defines - // (since the module itself might define the type of - // one of its symbols), so for now we hack it and - // assume everything at module level is 16 bytes or less. - // - // TODO: this also seems like it will cause problems - // in other contexts (any time the type of a register - // would depend on an earlier instruction in the same - // scope) so this needs careful thought. - VMType vmType = { nullptr }; - UInt regSize = 16; - UInt regAlign = 8; - - if (vmModuleInstance) - { - // We expect the type to come from the outer module, so - // that we can allocate space for it as we go. - vmType = *(VMType*)getRegPtrImpl(vmModuleInstance, typeGlobalID); + auto bcTypeID = bcReg->typeID; - regSize = vmType.getSize(); - regAlign = vmType.getAlignment(); - } + // We expect the type to come from the outer module, so + // that we can allocate space for it as we go. + auto vmType = getType(vmModule, bcTypeID); + + auto regSize = vmType.getSize(); + auto regAlign = vmType.getAlignment(); offset = (offset + (regAlign-1)) & ~(regAlign-1); @@ -292,10 +338,32 @@ VMFunc* loadVMFunc( for( UInt cc = 0; cc < constCount; ++cc ) { - auto globalID = bcFunc->consts[cc].globalID; - auto globalPtr = getRegPtrImpl(vmModuleInstance, globalID); - vmFunc->consts[cc].ptr = globalPtr; - vmFunc->consts[cc].type = vmModuleInstance->func->regs[globalID].type; + BCConst bcConst = bcFunc->consts[cc]; + switch( bcConst.flavor ) + { + case kBCConstFlavor_GlobalSymbol: + { + auto globalID = bcConst.id; + vmFunc->consts[cc].ptr = &vmModule->symbols[globalID]; + vmFunc->consts[cc].type = getGlobalType(vmModule, globalID); + } + break; + + case kBCConstFlavor_Constant: + { + auto constID = bcConst.id; + auto constInfo = &vmModule->bcModule->constants[constID]; + vmFunc->consts[cc].ptr = constInfo->ptr; + #if 0 + fprintf(stderr, "CONSANT[%d] : [%p]\n", (int)cc, vmFunc->consts[cc].ptr); + fprintf(stderr, "BC [%p] : %d\n", &constInfo->ptr, (int)constInfo->ptr.rawVal); + #endif + vmFunc->consts[cc].type = getType(vmModule, constInfo->typeID); + } + break; + } + + } return vmFunc; @@ -368,7 +436,7 @@ void dumpVMFrame(VMFrame* vmFrame) case kIROp_PtrType: { - fprintf(stderr, ": Ptr = 0x%p", *(void**)regData); + fprintf(stderr, ": Ptr = [%p]", *(void**)regData); } break; @@ -416,7 +484,186 @@ struct VMThread void resumeThread( VMThread* vmThread); -VMFrame* loadVMModuleInstance( +void computeTypeSizeAlign( + VMTypeImpl* impl) +{ + UInt size = 0; + UInt alignment = 0; + switch(impl->op) + { + case kIROp_VoidType: + size = 0; + break; + + case kIROp_BoolType: + size = 1; + break; + + case kIROp_Int32Type: + case kIROp_UInt32Type: + case kIROp_Float32Type: + size = 4; + break; + + case kIROp_FuncType: + case kIROp_PtrType: + case kIROp_readWriteStructuredBufferType: + case kIROp_structuredBufferType: + size = sizeof(void*); + break; + + default: + SLANG_UNIMPLEMENTED_X("type sizing"); + impl->size = 0; + break; + } + + if(!alignment) + alignment = size; + if(!alignment) + alignment = 1; + + impl->size = size; + impl->alignment = alignment; +} + +VMType getType( + VM* vm, + VMTypeImpl* typeImpl) +{ + // TODO: need to look up an existing type that matches... + + UInt argCount = typeImpl->argCount; + UInt size = sizeof(VMTypeImpl) + argCount*sizeof(VMType); + + VMTypeImpl* impl = (VMTypeImpl*) malloc(size); + memcpy(impl, typeImpl, size); + + computeTypeSizeAlign(impl); + + VMType type; + type.impl = impl; + return type; +} + +VMVal getVal( + VMModule* vmModule, + UInt index) +{ + return vmModule->types[index]; +} + +VMType loadVMType( + VMModule* vmModule, + BCType* bcType) +{ + // Need to load type from BC format to VM + IROp op = (IROp) bcType->op; + switch(bcType->op) + { + case kIROp_PtrType: + { + // TODO: need to do some caching! + BCPtrType* bcPtrType = (BCPtrType*) bcType; + + VMPtrTypeImpl vmPtrTypeImpl; + vmPtrTypeImpl.op = op; + vmPtrTypeImpl.argCount = 1; + vmPtrTypeImpl.size = sizeof(void*); + vmPtrTypeImpl.alignment = sizeof(void*); + vmPtrTypeImpl.base = getType(vmModule, bcPtrType->valueType->id); + + auto vmPtrType = getType(vmModule->vm, &vmPtrTypeImpl); + return vmPtrType; + } + break; + + default: + { + UInt argCount = bcType->argCount; + + UInt size = sizeof(VMTypeImpl) + argCount * sizeof(VMVal); + + VMTypeImpl* impl = (VMTypeImpl*) alloca(size); + memset(impl, 0, size); + impl->op = bcType->op; + impl->argCount = argCount; + + VMVal* args = (VMVal*) (impl + 1); + for(UInt aa = 0; aa < argCount; ++aa) + { + args[aa] = getVal(vmModule, bcType->getArg(aa)->id); + } + + return getType(vmModule->vm, impl); + } + + SLANG_UNEXPECTED("unimplemented"); + return VMType(); + break; + } +} + +void* allocateImpl(VM* vm, UInt size, UInt align) +{ + void* ptr = malloc(size); + memset(ptr, 0, size); + return ptr; +} + +void* allocate(VM* vm, VMType type) +{ + return allocateImpl(vm, type.getSize(), type.getAlignment()); +} + +template +T* allocate(VM* vm) +{ + return allocateImpl(vm, sizeof(T), alignof(T)); +} + +void* loadVMSymbol( + VMModule* vmModule, + BCSymbol* bcSymbol) +{ + // Need to load type from BC format to VM + + auto vm = vmModule->vm; + + switch(bcSymbol->op) + { + case kIROp_global_var: + { + auto type = getType(vmModule, bcSymbol->typeID); + assert(type.impl->op == kIROp_PtrType); + + VMPtrTypeImpl* ptrTypeImpl = (VMPtrTypeImpl*) type.impl; + auto valueType = ptrTypeImpl->base; + + void* varValue = allocate(vm, valueType); + void** varPtr = (void**) allocate(vm, type); + + + *varPtr = varValue; + + return varPtr; + } + break; + + case kIROp_Func: + { + auto bcFunc = (BCFunc*) bcSymbol; + VMFunc* vmFunc = loadVMFunc(bcFunc, vmModule); + return vmFunc; + } + break; + + default: + return nullptr; + } +} + +VMModule* loadVMModuleInstance( VM* vm, void const* bytecode, size_t bytecodeSize) @@ -425,45 +672,63 @@ VMFrame* loadVMModuleInstance( BCModule* bcModule = bcHeader->module; - UInt vmModuleSize = sizeof(VMModule) + bcModule->regCount * sizeof(VMReg); + UInt symbolCount = bcModule->symbolCount; + UInt typeCount = bcModule->typeCount; - VMModule* vmModule = (VMModule*) loadVMFunc(bcModule, nullptr); + UInt vmModuleSize = sizeof(VMModule) + + symbolCount * sizeof(void*) + + typeCount * sizeof(VMType); - // Create a frame to store the loaded symbols, and execute it - // to initialize them. - VMFrame* vmModuleInstance = createFrame(vmModule); - vmModuleInstance->parent = nullptr; - vmModule->module = vmModuleInstance; + VMModule* vmModule = (VMModule*)malloc(vmModuleSize); + memset(vmModule, 0, vmModuleSize); + void** vmSymbols = (void**)(vmModule + 1); + VMType* vmTypes = (VMType*)(vmSymbols + symbolCount); - VMThread thread; - thread.frame = vmModuleInstance; + vmModule->bcModule = bcModule; + vmModule->vm = vm; + vmModule->symbols = vmSymbols; + vmModule->types = vmTypes; - resumeThread(&thread); + // Initialize types before symbols, since the symbols + // will all have types... + for(UInt tt = 0; tt < typeCount; ++tt) + { + BCType* bcType = bcModule->types[tt]; + vmTypes[tt] = loadVMType(vmModule, bcType); + } + + // Now we need to initialize all the VM-level symbols + // from their BC-level equivalents. + for(UInt ss = 0; ss < symbolCount; ++ss) + { + BCSymbol* bcSymbol = bcModule->symbols[ss]; + vmSymbols[ss] = loadVMSymbol(vmModule, bcSymbol); + } - return vmModuleInstance; + return vmModule; } void* findGlobalSymbolPtr( - VMFrame* moduleInstance, + VMModule* module, char const* name) { // Okay, we need to search through the available - // "registers" looking for one that gives us a - // name match. - // - BCFunc* bcFunc = moduleInstance->func->bcFunc; - UInt regCount = bcFunc->regCount; - for (UInt rr = 0; rr < regCount; ++rr) + // symbols, looking for one that gives us a name + // match. + + BCModule* bcModule = module->bcModule; + UInt symbolCount = bcModule->symbolCount; + for(UInt ss = 0; ss < symbolCount; ++ss) { - BCReg* bcReg = &bcFunc->regs[rr]; + BCSymbol* bcSymbol = bcModule->symbols[ss]; - char const* symbolName = bcReg->name; + char const* symbolName = bcSymbol->name; if (!symbolName) continue; if(strcmp(symbolName, name) == 0) - return getRegPtrImpl(moduleInstance, rr); + return getGlobalPtr(module, ss); } return nullptr; @@ -516,182 +781,11 @@ void resumeThread( auto op = (IROp) decodeUInt(&ip); switch( op ) { - case kIROp_TypeType: - { - // The type of types - Int argCount = decodeUInt(&ip); - void* arg0Ptr = decodeOperandPtr(frame, &ip); - VMType* destPtr = decodeOperandPtr(frame, &ip); - - auto typeImpl = new VMTypeImpl(); - typeImpl->op = op; - typeImpl->size = sizeof(VMType); - typeImpl->alignment = alignof(VMType); - - VMType type = { typeImpl }; - *destPtr = type; - } - break; - - case kIROp_BlockType: - case kIROp_Int32Type: - case kIROp_UInt32Type: - case kIROp_Float32Type: - case kIROp_BoolType: - case kIROp_VoidType: - case kIROp_FuncType: // TODO: we should in principle handle function types here - { - // Case to handle types without arguments. - UInt argCount = decodeUInt(&ip); - for( UInt aa = 0; aa < argCount; ++aa ) - { - void* argPtr = decodeOperandPtr(frame, &ip); - } - VMType* destPtr = decodeOperandPtr(frame, &ip); - - auto typeImpl = new VMTypeImpl(); - typeImpl->op = op; - - UInt size = 1; - UInt align = 0; - switch (op) - { - case kIROp_BlockType: size = sizeof(void*); break; - case kIROp_Int32Type: size = sizeof(int32_t); break; - case kIROp_UInt32Type: size = sizeof(uint32_t); break; - case kIROp_Float32Type: size = sizeof(float); break; - case kIROp_BoolType: size = sizeof(bool); break; - case kIROp_VoidType: size = 0; align = 1; break; - default: - break; - } - if (!align) align = size; - typeImpl->size = size; - typeImpl->alignment = align; - - VMType type = { typeImpl }; - *destPtr = type; - } - break; - - case kIROp_PtrType: - { - // Case to handle types without arguments. - UInt argCount = decodeUInt(&ip); - VMType* typeTypePtr = decodeOperandPtr(frame, &ip); - VMType baseType = decodeOperand(frame, &ip); - int32_t addressSpace = decodeOperand(frame, &ip); - VMType* destPtr = decodeOperandPtr(frame, &ip); - - - - auto typeImpl = new VMPtrTypeImpl(); - typeImpl->op = op; - typeImpl->size = sizeof(void*); - typeImpl->alignment = alignof(void*); - typeImpl->base = baseType; - typeImpl->addressSpace = addressSpace; - - VMType type = { typeImpl }; - *destPtr = type; - } - break; - - case kIROp_readWriteStructuredBufferType: - case kIROp_structuredBufferType: - { - // Case to handle types without arguments. - UInt argCount = decodeUInt(&ip); - VMType* typeTypePtr = decodeOperandPtr(frame, &ip); - VMType baseType = decodeOperand(frame, &ip); - VMType* destPtr = decodeOperandPtr(frame, &ip); - - - // TODO: give these their own representations! - auto typeImpl = new VMPtrTypeImpl(); - typeImpl->op = op; - typeImpl->base = baseType; - typeImpl->size = sizeof(void*); - typeImpl->alignment = sizeof(void*); - - VMType type = { typeImpl }; - *destPtr = type; - } - break; - - - case kIROp_IntLit: - { - VMType type = decodeOperand(frame, &ip); - UInt uVal = decodeUInt(&ip); - void* destPtr = decodeOperandPtr(frame, &ip); - - switch( type.impl->op ) - { - case kIROp_Int32Type: - *(int32_t*)destPtr = int32_t(uVal); - break; - - case kIROp_UInt32Type: - *(uint32_t*)destPtr = uint32_t(uVal); - break; - - default: - SLANG_UNEXPECTED("integer type"); - break; - } - - } - break; - - case kIROp_FloatLit: - { - VMType type = decodeOperand(frame, &ip); - - static const UInt size = sizeof(IRFloatingPointValue); - IRFloatingPointValue value; - memcpy(&value, ip, size); - ip += size; - void* destPtr = decodeOperandPtr(frame, &ip); - - switch( type.impl->op ) - { - case kIROp_Float32Type: - *(float*)destPtr = float(value); - break; - - default: - SLANG_UNEXPECTED("float type"); - break; - } - } - break; - - case kIROp_boolConst: - { - bool val = (*ip++) != 0; - bool* destPtr = decodeOperandPtr(frame, &ip); - *destPtr = val; - } - break; - - case kIROp_Func: - { - UInt nestedID = decodeUInt(&ip); - void* destPtr = decodeOperandPtr(frame, &ip); - - BCSymbol* bcSymbol = frame->func->bcFunc->nestedSymbols[nestedID]; - BCFunc* bcFunc = (BCFunc*)bcSymbol; - VMFunc* vmFunc = loadVMFunc(bcFunc, frame->func->module); - - *(VMFunc**)destPtr = vmFunc; - } - break; - case kIROp_Var: { // This instruction represents the `alloca` for a variable of some type. + VMType type = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); void* argPtrs[16] = { 0 }; for( UInt aa = 0; aa < argCount; ++aa ) @@ -714,10 +808,17 @@ void resumeThread( case kIROp_Store: { // An ordinary memory store - VMType type = decodeOperand(frame, &ip); + VMType type = decodeType(frame, &ip); void* dest = decodeOperand(frame, &ip); void* src = decodeOperandPtr(frame, &ip); +#if 0 + fprintf(stderr, "STORE *[%p] = [%p] // size: %d\n", + dest, + src, + (int) type.getSize()); +#endif + memcpy(dest, src, type.getSize()); } break; @@ -725,7 +826,7 @@ void resumeThread( case kIROp_Load: { // An ordinary memory store - VMType type = decodeOperand(frame, &ip); + VMType type = decodeType(frame, &ip); void* src = decodeOperand(frame, &ip); void* dest = decodeOperandPtr(frame, &ip); @@ -735,6 +836,7 @@ void resumeThread( case kIROp_BufferLoad: { + VMType type = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); void* argPtrs[16] = { 0 }; for( UInt aa = 0; aa < argCount; ++aa ) @@ -745,9 +847,8 @@ void resumeThread( void* dest = decodeOperandPtr(frame, &ip); - VMType type = *(VMType*)argPtrs[0]; - char* bufferData = *(char**)argPtrs[1]; - uint32_t index = *(uint32_t*)argPtrs[2]; + char* bufferData = *(char**)argPtrs[0]; + uint32_t index = *(uint32_t*)argPtrs[1]; auto size = type.getSize(); char* elementData = bufferData + index*size; @@ -757,9 +858,9 @@ void resumeThread( case kIROp_BufferStore: { + VMType resultType = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); - VMType* resultTypePtr = decodeOperandPtr(frame, &ip); char* bufferData = decodeOperand(frame, &ip); uint32_t index = decodeOperand(frame, &ip); @@ -774,8 +875,10 @@ void resumeThread( break; case kIROp_Call: { + VMType type = decodeType(frame, &ip); UInt operandCount = decodeUInt(&ip); - VMType type = decodeOperand(frame, &ip); + + // First operand is the callee function VMFunc* func = decodeOperand(frame, &ip); // Okay, we need to create a frame to prepare the call @@ -784,7 +887,7 @@ void resumeThread( // Remaining arguments should populate the // first N registers of the callee - UInt argCount = operandCount - 2; + UInt argCount = operandCount - 1; for( UInt aa = 0; aa < argCount; ++aa ) { void* argPtr = decodeOperandPtr(frame, &ip); @@ -836,8 +939,8 @@ void resumeThread( case kIROp_ReturnVal: { + VMType instType = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); - void* typePtr = decodeOperandPtr(frame, &ip); void* argPtr = decodeOperandPtr(frame, &ip); VMFrame* oldFrame = frame; @@ -868,8 +971,8 @@ void resumeThread( // For now our encoding is very regular, so we can decode without // knowing too much about an instruction... + VMType type = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); - void* typePtr = decodeOperandPtr(frame, &ip); Int destinationBlock = decodeSInt(&ip); for( UInt aa = 2; aa < argCount; ++aa ) { @@ -892,8 +995,8 @@ void resumeThread( // For now our encoding is very regular, so we can decode without // knowing too much about an instruction... + VMType type = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); - void* typePtr = decodeOperandPtr(frame, &ip); bool* condition = decodeOperandPtr(frame, &ip); Int trueBlockID = decodeSInt(&ip); Int falseBlockID = decodeSInt(&ip); @@ -917,9 +1020,9 @@ void resumeThread( // For now our encoding is very regular, so we can decode without // knowing too much about an instruction... + VMType resultType = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); void* argPtrs[16] = { 0 }; - VMType resultType = decodeOperand(frame, &ip); auto leftOpnd = decodeOperandPtrAndType(frame, &ip); auto type = leftOpnd.type; auto leftPtr = leftOpnd.ptr; @@ -942,8 +1045,8 @@ void resumeThread( case kIROp_Mul: { + VMType type = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); - VMType type = decodeOperand(frame, &ip); void* leftPtr = decodeOperandPtr(frame, &ip); void* rightPtr = decodeOperandPtr(frame, &ip); @@ -964,8 +1067,8 @@ void resumeThread( case kIROp_Sub: { + VMType type = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); - VMType type = decodeOperand(frame, &ip); void* leftPtr = decodeOperandPtr(frame, &ip); void* rightPtr = decodeOperandPtr(frame, &ip); @@ -989,6 +1092,7 @@ void resumeThread( // For now our encoding is very regular, so we can decode without // knowing too much about an instruction... + VMType type = decodeType(frame, &ip); UInt argCount = decodeUInt(&ip); void* argPtrs[16] = { 0 }; for( UInt aa = 0; aa < argCount; ++aa ) @@ -1039,7 +1143,7 @@ SLANG_API void* SlangVMModule_findGlobalSymbolPtr( char const* name) { return (SlangVMFunc*) Slang::findGlobalSymbolPtr( - (Slang::VMFrame*) module, + (Slang::VMModule*) module, name); } diff --git a/tests/ir/loop.slang.expected b/tests/ir/loop.slang.expected index a0f8de432..e34bb68cc 100644 --- a/tests/ir/loop.slang.expected +++ b/tests/ir/loop.slang.expected @@ -2,77 +2,80 @@ result code = 0 standard error = { } standard output = { -let %63 : Ptr,64>,1> = var() -let %96 : Ptr>,0> = var() -let %282 : Ptr>,0> = var() -func @_S04mainp3 : (Int32, Int32, Int32) -> Void +ir_global_var %1 : Ptr[64]>; + +ir_global_var %2 : Ptr>>; + +ir_global_var %3 : Ptr>>; + +ir_func @_S04mainp3 : (uint, uint, uint) -> void { -block %14( param %15 : Int32, - param %24 : Int32, - param %32 : Int32) -: - let %21 : Ptr = var() - store(%21, %15) - let %29 : Ptr = var() - store(%29, %24) - let %37 : Ptr = var() - store(%37, %32) - let %64 : Int32 = load(%29) - let %69 : Ptr,0> = getElementPtr(%63, %64) - let %97 : StructuredBuffer> = load(%96) - let %100 : Int32 = load(%21) - let %101 : Vec = bufferLoad(%97, %100) - store(%69, %101) - let %110 : Ptr = var() - let %118 : Int32 = construct(1) - store(%110, %118) - loop(%123, %129, %132) +block %4( + param %5 : uint, + param %6 : uint, + param %7 : uint): + let %8 : Ptr = var() + store(%8, %5) + let %9 : Ptr = var() + store(%9, %6) + let %10 : Ptr = var() + store(%10, %7) + let %11 : uint = load(%9) + let %12 : Ptr> = getElementPtr(%1, %11) + let %13 : StructuredBuffer> = load(%2) + let %14 : uint = load(%8) + let %15 : vector = bufferLoad(%13, %14) + store(%12, %15) + let %16 : Ptr = var() + let %17 : uint = construct(1) + store(%16, %17) + loop(%18, %19, %20) -block %123: - let %139 : Int32 = load(%110) - let %148 : Int32 = construct(64) - let %149 : Bool = cmpLT(%139, %148) - loopTest(%149, %126, %129) +block %18: + let %21 : uint = load(%16) + let %22 : uint = construct(64) + let %23 : bool = cmpLT(%21, %22) + loopTest(%23, %24, %19) -block %126: +block %24: GroupMemoryBarrierWithGroupSync() - let %174 : Int32 = load(%29) - let %179 : Ptr,0> = getElementPtr(%63, %174) - let %184 : Ptr,0> = var() - let %185 : Vec = load(%179) - store(%184, %185) - let %204 : Int32 = load(%29) - let %207 : Int32 = load(%110) - let %208 : Int32 = sub(%204, %207) - let %213 : Ptr,0> = getElementPtr(%63, %208) - let %214 : Vec = load(%213) - let %215 : Vec = load(%184) - let %216 : Vec = add(%215, %214) - store(%184, %216) - let %219 : Vec = load(%184) - store(%179, %219) - unconditionalBranch(%132) + let %25 : uint = load(%9) + let %26 : Ptr> = getElementPtr(%1, %25) + let %27 : Ptr> = var() + let %28 : vector = load(%26) + store(%27, %28) + let %29 : uint = load(%9) + let %30 : uint = load(%16) + let %31 : uint = sub(%29, %30) + let %32 : Ptr> = getElementPtr(%1, %31) + let %33 : vector = load(%32) + let %34 : vector = load(%27) + let %35 : vector = add(%34, %33) + store(%27, %35) + let %36 : vector = load(%27) + store(%26, %36) + unconditionalBranch(%20) -block %132: - let %232 : Ptr = var() - let %233 : Int32 = load(%110) - store(%232, %233) - let %244 : Int32 = construct(1) - let %245 : Int32 = load(%232) - let %246 : Int32 = shl(%245, %244) - store(%232, %246) - let %249 : Int32 = load(%232) - store(%110, %249) - unconditionalBranch(%123) +block %20: + let %37 : Ptr = var() + let %38 : uint = load(%16) + store(%37, %38) + let %39 : uint = construct(1) + let %40 : uint = load(%37) + let %41 : uint = shl(%40, %39) + store(%37, %41) + let %42 : uint = load(%37) + store(%16, %42) + unconditionalBranch(%18) -block %129: +block %19: GroupMemoryBarrierWithGroupSync() - let %283 : RWStructuredBuffer> = load(%282) - let %286 : Int32 = load(%21) - let %300 : Ptr,0> = getElementPtr(%63, 0) - let %301 : Vec = load(%300) - bufferStore(%283, %286, %301) + let %43 : RWStructuredBuffer> = load(%3) + let %44 : uint = load(%8) + let %45 : Ptr> = getElementPtr(%1, 0) + let %46 : vector = load(%45) + bufferStore(%43, %44, %46) return_void() } } diff --git a/tools/eval-test/main.cpp b/tools/eval-test/main.cpp index f8736a07b..9fb6f94a3 100644 --- a/tools/eval-test/main.cpp +++ b/tools/eval-test/main.cpp @@ -80,15 +80,15 @@ int main( bytecode, bytecodeSize); - SlangVMFunc* vmFunc = *(SlangVMFunc**)SlangVMModule_findGlobalSymbolPtr( + SlangVMFunc* vmFunc = (SlangVMFunc*)SlangVMModule_findGlobalSymbolPtr( vmModule, "main"); - int32_t*& inputArg = **(int32_t***)SlangVMModule_findGlobalSymbolPtr( + int32_t*& inputArg = *(int32_t**)SlangVMModule_findGlobalSymbolPtr( vmModule, "input"); - int32_t*& outputArg = **(int32_t***)SlangVMModule_findGlobalSymbolPtr( + int32_t*& outputArg = *(int32_t**)SlangVMModule_findGlobalSymbolPtr( vmModule, "output"); -- cgit v1.2.3