diff options
| author | Ellie Hermaszewska <ellieh@nvidia.com> | 2023-08-26 01:42:34 +0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-08-25 10:42:34 -0700 |
| commit | ef4c9f1f1c297f1a33be95795a7a7561e0cc3bde (patch) | |
| tree | 9ea81689432040905772aeec447adad88f212e01 /source | |
| parent | 036abc85ba1db9c8c06289f0a0492e9a95a228b9 (diff) | |
Initial version of spirv_asm block (#3151)
* Initial version of spirv_asm block
* Correct indentation of parent instruction dumping
* neater dumping for spirv_asm instructions
* Add $$ DollarDollar token
* Allow passing addresses to spirv_asm blocks
* spirv OpUndef
* String literals in spirv asm
* OpName for spirv_asm ids
* Correct failure in lower spirv_asm
* correct position for spirv_asm idents
* comment correct
* several more tests for spirv_asm blocks
* Fill out some unimplemented functions for spirv_asm expressions
---------
Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
| -rw-r--r-- | source/compiler-core/slang-lexer.cpp | 11 | ||||
| -rw-r--r-- | source/compiler-core/slang-token-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/slang-ast-dump.cpp | 49 | ||||
| -rw-r--r-- | source/slang/slang-ast-expr.h | 37 | ||||
| -rw-r--r-- | source/slang/slang-ast-iterator.h | 11 | ||||
| -rw-r--r-- | source/slang/slang-check-expr.cpp | 60 | ||||
| -rw-r--r-- | source/slang/slang-check-impl.h | 2 | ||||
| -rw-r--r-- | source/slang/slang-diagnostic-defs.h | 7 | ||||
| -rw-r--r-- | source/slang/slang-emit-spirv.cpp | 123 | ||||
| -rw-r--r-- | source/slang/slang-ir-inst-defs.h | 20 | ||||
| -rw-r--r-- | source/slang/slang-ir-insts.h | 50 | ||||
| -rw-r--r-- | source/slang/slang-ir.cpp | 103 | ||||
| -rw-r--r-- | source/slang/slang-language-server-ast-lookup.cpp | 12 | ||||
| -rw-r--r-- | source/slang/slang-lower-to-ir.cpp | 102 | ||||
| -rw-r--r-- | source/slang/slang-parser.cpp | 120 | ||||
| -rw-r--r-- | source/slang/slang-serialize-ast-type-info.h | 3 |
16 files changed, 705 insertions, 6 deletions
diff --git a/source/compiler-core/slang-lexer.cpp b/source/compiler-core/slang-lexer.cpp index 24cd3034b..5954dc668 100644 --- a/source/compiler-core/slang-lexer.cpp +++ b/source/compiler-core/slang-lexer.cpp @@ -1314,7 +1314,16 @@ namespace Slang case '?': _advance(lexer); return TokenType::QuestionMark; case '@': _advance(lexer); return TokenType::At; - case '$': _advance(lexer); return TokenType::Dollar; + case '$': + { + _advance(lexer); + if(_peek(lexer) == '$') + { + _advance(lexer); + return TokenType::DollarDollar; + } + return TokenType::Dollar; + } } diff --git a/source/compiler-core/slang-token-defs.h b/source/compiler-core/slang-token-defs.h index 45b4912e7..2a66359fe 100644 --- a/source/compiler-core/slang-token-defs.h +++ b/source/compiler-core/slang-token-defs.h @@ -85,6 +85,7 @@ PUNCTUATION(Colon, ":") PUNCTUATION(RightArrow, "->") PUNCTUATION(At, "@") PUNCTUATION(Dollar, "$") +PUNCTUATION(DollarDollar, "$$") PUNCTUATION(Pound, "#") PUNCTUATION(PoundPound, "##") diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp index d016d1c15..7dba55c52 100644 --- a/source/slang/slang-ast-dump.cpp +++ b/source/slang/slang-ast-dump.cpp @@ -644,6 +644,55 @@ struct ASTDumpContext m_writer->emit(")"); } + void dump(const SPIRVAsmOperand& operand) + { + switch(operand.flavor) + { + case SPIRVAsmOperand::Id: + m_writer->emit("%"); + break; + case SPIRVAsmOperand::Literal: + case SPIRVAsmOperand::NamedValue: + break; + case SPIRVAsmOperand::SlangValue: + m_writer->emit("$"); + break; + case SPIRVAsmOperand::SlangValueAddr: + m_writer->emit("&"); + break; + case SPIRVAsmOperand::SlangType: + m_writer->emit("$$"); + break; + default: + SLANG_UNREACHABLE("Unhandled case in ast dump for SPIRVAsmOperand"); + } + if(operand.expr) + dump(operand.expr); + else + dump(operand.token); + } + + void dump(const SPIRVAsmInst& inst) + { + dump(inst.opcode); + for(const auto& o : inst.operands) + dump(o); + } + + void dump(const SPIRVAsmExpr& expr) + { + m_writer->emit("spirv_asm\n"); + m_writer->emit("{\n"); + m_writer->indent(); + for(const auto& i : expr.insts) + { + dump(i); + m_writer->emit(";\n"); + } + m_writer->dedent(); + m_writer->emit("}"); + } + void dumpObjectFull(NodeBase* node); ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle): diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h index 28ce2e4d1..36d304a1a 100644 --- a/source/slang/slang-ast-expr.h +++ b/source/slang/slang-ast-expr.h @@ -631,4 +631,41 @@ public: List<Val*> knownGenericArgs; }; +class SPIRVAsmOperand +{ + SLANG_VALUE_CLASS(SPIRVAsmOperand); + +public: + enum Flavor + { + Literal, // No prefix + Id, // Prefixed with % + NamedValue, // An identifier + SlangValue, + SlangValueAddr, + SlangType, + }; + Flavor flavor; + Token token; + Expr* expr = nullptr; + TypeExp type = TypeExp(); +}; + +class SPIRVAsmInst +{ + SLANG_VALUE_CLASS(SPIRVAsmInst); + +public: + SPIRVAsmOperand opcode; + List<SPIRVAsmOperand> operands; +}; + +class SPIRVAsmExpr : public Expr +{ + SLANG_AST_CLASS(SPIRVAsmExpr); + +public: + List<SPIRVAsmInst> insts; +}; + } // namespace Slang diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h index fb3d50b4f..76effa608 100644 --- a/source/slang/slang-ast-iterator.h +++ b/source/slang/slang-ast-iterator.h @@ -282,6 +282,17 @@ struct ASTIterator { dispatchIfNotNull(expr->innerExpr); } + + void visitSPIRVAsmExpr(SPIRVAsmExpr* expr) + { + iterator->maybeDispatchCallback(expr); + for(const auto& i : expr->insts) + { + dispatchIfNotNull(i.opcode.expr); + for(const auto& o : i.operands) + dispatchIfNotNull(o.expr); + } + } }; struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor> diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp index 3d2f81edb..5266f02f6 100644 --- a/source/slang/slang-check-expr.cpp +++ b/source/slang/slang-check-expr.cpp @@ -3895,4 +3895,64 @@ namespace Slang return expr; } + + Expr* SemanticsExprVisitor::visitSPIRVAsmExpr(SPIRVAsmExpr* expr) + { + // We will iterate over all the operands in all the insts and check + // them + for(auto& inst : expr->insts) + { + const bool isLast = &inst == &expr->insts.getLast(); + for(auto& operand : inst.operands) + { + if(operand.flavor == SPIRVAsmOperand::SlangType) + { + // This is a $$type operand, fill in the TypeExp member of the operand + TypeExp& typeExpr = operand.type; + typeExpr.exp = operand.expr; + typeExpr = CheckProperType(typeExpr); + operand.expr = typeExpr.exp; + } + else if(operand.flavor == SPIRVAsmOperand::SlangValue + || operand.flavor == SPIRVAsmOperand::SlangValueAddr) + { + // This is a $expr operand, check the expr + operand.expr = dispatch(operand.expr); + } + else if(operand.flavor == SPIRVAsmOperand::NamedValue + && operand.token.getContent() == "result") + { + // This is the <result-id> marker, check that it only + // appears in the last instruction. + + // TODO: We could consider relaxing this, because SPIR-V + // does have forward references for decorations and such + if (!isLast) + { + getSink()->diagnose(operand.token, Diagnostics::misplacedResultIdMarker); + getSink()->diagnoseWithoutSourceView(expr, Diagnostics::considerOpCopyObject); + } + } + } + } + + // Assign the type of this expression from the type of the last + // instruction, otherwise void + if(expr->insts.getCount()) + { + // TODO: we trust that this is correct, but could should verify + const auto lastOperands = expr->insts.getLast().operands; + if(lastOperands.getCount() >= 2 + && lastOperands[0].flavor == SPIRVAsmOperand::SlangType + && lastOperands[1].flavor == SPIRVAsmOperand::NamedValue + && lastOperands[1].token.getContent() == "result") + { + expr->type = lastOperands[0].type.type; + } + } + if(!expr->type) + expr->type = m_astBuilder->getVoidType(); + + return expr; + } } diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h index 8e37e7967..1e3bde4de 100644 --- a/source/slang/slang-check-impl.h +++ b/source/slang/slang-check-impl.h @@ -2424,6 +2424,8 @@ namespace Slang Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr); + Expr* visitSPIRVAsmExpr(SPIRVAsmExpr*); + /// Perform semantic checking on a `modifier` that is being applied to the given `type` Val* checkTypeModifier(Modifier* modifier, Type* type); diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h index 0243543e9..ac6e9a932 100644 --- a/source/slang/slang-diagnostic-defs.h +++ b/source/slang/slang-diagnostic-defs.h @@ -238,9 +238,14 @@ DIAGNOSTIC(20015, Error, unknownSPIRVCapability, "unknown SPIR-V capability '$0' DIAGNOSTIC(20101, Warning, unintendedEmptyStatement, "potentially unintended empty statement at this location; use {} instead.") -// 29xxx - Snippet parsing +// 29xxx - Snippet parsing and inline asm DIAGNOSTIC(29000, Error, snippetParsingFailed, "unable to parse target intrinsic snippet: $0") +DIAGNOSTIC(29100, Error, unrecognizedSPIRVOpcode, "unrecognized spirv opcode: $0") +DIAGNOSTIC(29101, Error, misplacedResultIdMarker, "the result-id marker must only be used in the last instruction of a spriv_asm expression") +DIAGNOSTIC(29102, Note, considerOpCopyObject, "consider adding an OpCopyObject instruction to the end of the spirv_asm expression") +DIAGNOSTIC(29103, Note, noSuchAddress, "unable to take the address of this address-of asm operand") + // // 3xxxx - Semantic analysis diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp index 29114424d..0022bdd85 100644 --- a/source/slang/slang-emit-spirv.cpp +++ b/source/slang/slang-emit-spirv.cpp @@ -10,6 +10,7 @@ #include "slang-ir-spirv-snippet.h" #include "slang-ir-spirv-legalize.h" #include "slang-spirv-val.h" +#include "slang-lookup-spirv.h" #include "spirv/unified1/spirv.h" #include "../core/slang-memory-arena.h" #include <type_traits> @@ -577,13 +578,18 @@ struct SPIRVEmitContext // // We will allocate <id>s on emand as they are needed. + SpvWord freshID() + { + return m_nextID++; + } + /// Get the <id> for `inst`, or assign one if it doesn't have one yet SpvWord getID(SpvInst* inst) { auto id = inst->id; if( !id ) { - id = m_nextID++; + id = freshID(); inst->id = id; } return id; @@ -2019,7 +2025,10 @@ struct SPIRVEmitContext return emitDebugLine(parent, as<IRDebugLine>(inst)); case kIROp_GetStringHash: return emitGetStringHash(inst); - + case kIROp_undefined: + return emitOpUndef(parent, inst, inst->getDataType()); + case kIROp_SPIRVAsm: + return emitSPIRVAsm(parent, as<IRSPIRVAsm>(inst)); } } @@ -3789,6 +3798,116 @@ struct SPIRVEmitContext debugLine->getColEnd()); } + SpvInst* emitSPIRVAsm(SpvInstParent* parent, IRSPIRVAsm* inst) + { + SpvInst* last = nullptr; + + // This keeps track of the named IDs used in the asm block + Dictionary<UnownedStringSlice, SpvWord> idMap; + + for(const auto spvInst : inst->getInsts()) + { + const bool isLast = spvInst == inst->getLastChild(); + const auto opcodeString = spvInst->getOpcodeString(); + SpvOp opcode; + const bool foundOpCode = lookupSpvOp(opcodeString, opcode) + || lookupSpvOp((String("Op") + opcodeString).getUnownedSlice(), opcode); + if(!foundOpCode) + { + m_sink->diagnose( + spvInst->getOpcode(), + Diagnostics::unrecognizedSPIRVOpcode, + opcodeString + ); + return nullptr; + } + + const auto parentForOpCode = [this](SpvOp opcode, SpvInstParent* defaultParent){ + return + opcode == SpvOpConstant ? getSection(SpvLogicalSectionID::ConstantsAndTypes) + : opcode == SpvOpName ? getSection(SpvLogicalSectionID::DebugNames) + : defaultParent; + }; + + last = emitInstCustomOperandFunc( + parentForOpCode(opcode, parent), + // We want the "result instruction" to refer to the top level + // block which assumes its value, the others are free to refer + // to whatever, so just use the internal spv inst rep + // TODO: This is not correct, because the instruction which is + // assigned to result is not necessarily the last instruction + isLast ? as<IRInst>(inst) : spvInst, + opcode, + [&](){ + for(const auto operand : spvInst->getSPIRVOperands()) + { + switch(operand->getOp()) + { + case kIROp_SPIRVAsmOperandLiteral: + { + const auto v = as<IRConstant>(operand->getValue()); + SLANG_ASSERT(v); + switch(v->getOp()) + { + case kIROp_StringLit: + emitOperand(SpvLiteralBits::fromUnownedStringSlice(v->getStringSlice())); + break; + case kIROp_IntLit: + { + // TODO: range checking + const auto i = cast<IRIntLit>(v)->getValue(); + emitOperand(SpvLiteralInteger::from32(uint32_t(i))); + break; + } + default: + SLANG_UNREACHABLE("Unhandled case in emitSPIRVAsm"); + } + break; + } + case kIROp_SPIRVAsmOperandInst: + { + const auto i = operand->getValue(); + emitOperand(ensureInst(i)); + break; + } + case kIROp_SPIRVAsmOperandEnum: + { + const auto s = cast<IRStringLit>(operand->getValue())->getStringSlice(); + if(s == "result") + { + SLANG_ASSERT(isLast); + emitOperand(kResultID); + } + else + SLANG_UNIMPLEMENTED_X("lookup enum operands in spirv_asm"); + break; + } + case kIROp_SPIRVAsmOperandId: + { + const auto idName = cast<IRStringLit>(operand->getValue())->getStringSlice(); + SpvWord id; + if(!idMap.tryGetValue(idName, id)) + { + id = freshID(); + idMap.set(idName, id); + } + emitOperand(id); + break; + } + default: + SLANG_UNREACHABLE("Unhandled case in emitSPIRVAsm"); + } + } + } + ); + } + + for(const auto& [name, id] : idMap) + emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name); + + return last; + } + OrderedHashSet<SpvCapability> m_capabilities; void requireSPIRVCapability(SpvCapability capability) diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h index d264dfd06..980770be5 100644 --- a/source/slang/slang-ir-inst-defs.h +++ b/source/slang/slang-ir-inst-defs.h @@ -1053,6 +1053,26 @@ INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0) INST(DebugSource, DebugSource, 2, HOISTABLE) INST(DebugLine, DebugLine, 5, 0) +/* Inline assembly */ +INST(SPIRVAsm, SPIRVAsm, 0, PARENT) +INST(SPIRVAsmInst, SPIRVAsmInst, 1, 0) + // These instruction serve to inform the backend precisely how to emit each + // instruction, consider the difference between emitting a literal integer + // and a reference to a literal integer instruction + // + // A literal string or 32-bit integer to be passed as operands + INST(SPIRVAsmOperandLiteral, SPIRVAsmOperandLiteral, 1, 0) + // A reference to a slang IRInst, either a value or a type + INST(SPIRVAsmOperandInst, SPIRVAsmOperandInst, 1, 0) + // A named enumerator, the value of which is determined in the backend + // It can also have the value "result", indicating that the result-id of + // the asm block should be used + INST(SPIRVAsmOperandEnum, SPIRVAsmOperandEnum, 1, 0) + // A string which is given a unique ID in the backend, used to refer to + // results of other instrucions in the same asm block + INST(SPIRVAsmOperandId, SPIRVAsmOperandId, 1, 0) +INST_RANGE(SPIRVAsmOperand, SPIRVAsmOperandLiteral, SPIRVAsmOperandId) + #undef PARENT #undef USE_OTHER #undef INST_RANGE diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h index 5639d90dd..e12306c54 100644 --- a/source/slang/slang-ir-insts.h +++ b/source/slang/slang-ir-insts.h @@ -2883,6 +2883,49 @@ struct IRDebugLine : IRInst IRInst* getColEnd() { return getOperand(4); } }; +struct IRSPIRVAsmOperand : IRInst +{ + IR_PARENT_ISA(SPIRVAsmOperand); + IRInst* getValue() + { + return getOperand(0); + } +}; + +struct IRSPIRVAsmInst : IRInst +{ + IR_LEAF_ISA(SPIRVAsmInst); + + IRSPIRVAsmOperand* getOpcode() + { + // TODO: This only supports known opcodes at the moment, eventually we'll want + // another child of IRSPIRVAsm which just stores raw words + const auto opcodeOperand = cast<IRSPIRVAsmOperand>(getOperand(0)); + SLANG_ASSERT(opcodeOperand->getOp() == kIROp_SPIRVAsmOperandEnum); + return opcodeOperand; + } + + UnownedStringSlice getOpcodeString() + { + const auto opcodeOperand = getOpcode(); + const auto opcodeStringLit = cast<IRStringLit>(opcodeOperand->getValue()); + return opcodeStringLit->getStringSlice(); + } + + IROperandList<IRSPIRVAsmOperand> getSPIRVOperands() + { + return IROperandList<IRSPIRVAsmOperand>(getOperands() + 1, getOperands() + getOperandCount()); + } +}; + +struct IRSPIRVAsm : IRInst +{ + IR_LEAF_ISA(SPIRVAsm); + IRFilteredInstList<IRSPIRVAsmInst> getInsts() + { + return IRFilteredInstList<IRSPIRVAsmInst>(getFirstChild(), getLastChild()); + } +}; struct IRBuilderSourceLocRAII; @@ -3870,6 +3913,13 @@ public: IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1); IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1); + IRSPIRVAsmOperand* emitSPIRVAsmOperandLiteral(IRInst* literal); + IRSPIRVAsmOperand* emitSPIRVAsmOperandInst(IRInst* inst); + IRSPIRVAsmOperand* emitSPIRVAsmOperandId(IRInst* inst); + IRSPIRVAsmOperand* emitSPIRVAsmOperandEnum(IRInst* inst); + IRSPIRVAsmInst* emitSPIRVAsmInst(IRInst* opcode, List<IRInst*> operands); + IRSPIRVAsm* emitSPIRVAsm(IRType* type); + // // Decorations // diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp index 181970632..91a21754d 100644 --- a/source/slang/slang-ir.cpp +++ b/source/slang/slang-ir.cpp @@ -5664,6 +5664,84 @@ namespace Slang return inst; } + IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandLiteral(IRInst* literal) + { + SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent())); + const auto i = createInst<IRSPIRVAsmOperand>( + this, + kIROp_SPIRVAsmOperandLiteral, + literal->getFullType(), + literal + ); + addInst(i); + return i; + } + + IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandInst(IRInst* inst) + { + SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent())); + const auto i = createInst<IRSPIRVAsmOperand>( + this, + kIROp_SPIRVAsmOperandInst, + inst->getFullType(), + inst + ); + addInst(i); + return i; + } + + IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandId(IRInst* inst) + { + SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent())); + const auto i = createInst<IRSPIRVAsmOperand>( + this, + kIROp_SPIRVAsmOperandId, + inst->getFullType(), + inst + ); + addInst(i); + return i; + } + + IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandEnum(IRInst* inst) + { + SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent())); + const auto i = createInst<IRSPIRVAsmOperand>( + this, + kIROp_SPIRVAsmOperandEnum, + inst->getFullType(), + inst + ); + addInst(i); + return i; + } + + IRSPIRVAsmInst* IRBuilder::emitSPIRVAsmInst(IRInst* opcode, List<IRInst*> operands) + { + SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent())); + operands.insert(0, opcode); + const auto i = createInst<IRSPIRVAsmInst>( + this, + kIROp_SPIRVAsmInst, + getVoidType(), + operands.getCount(), + operands.getBuffer() + ); + addInst(i); + return i; + } + + IRSPIRVAsm* IRBuilder::emitSPIRVAsm(IRType* type) + { + const auto asmInst = createInst<IRSPIRVAsm>( + this, + kIROp_SPIRVAsm, + type + ); + addInst(asmInst); + return asmInst; + } + // // Decorations // @@ -6158,6 +6236,9 @@ namespace Slang if(as<IRType>(inst)) return true; + if(as<IRSPIRVAsmOperand>(inst)) + return true; + return false; } @@ -6370,7 +6451,6 @@ namespace Slang { auto opInfo = getIROpInfo(inst->getOp()); - dumpIndent(context); dump(context, opInfo.name); dump(context, " "); dumpID(context, inst); @@ -6398,6 +6478,7 @@ namespace Slang } context->indent--; + dumpIndent(context); dump(context, "}\n"); } @@ -6468,6 +6549,25 @@ namespace Slang } } + // Special case the SPIR-V asm operands as the distinction here is + // clear anyway to the user + switch(op) + { + case kIROp_SPIRVAsmOperandEnum: + dumpInstExpr(context, inst->getOperand(0)); + return; + case kIROp_SPIRVAsmOperandLiteral: + dumpInstExpr(context, inst->getOperand(0)); + return; + case kIROp_SPIRVAsmOperandInst: + dumpInstExpr(context, inst->getOperand(0)); + return; + case kIROp_SPIRVAsmOperandId: + dump(context, "%"); + dumpInstExpr(context, inst->getOperand(0)); + return; + } + dump(context, opInfo.name); dumpInstOperandList(context, inst); } @@ -6501,6 +6601,7 @@ namespace Slang case kIROp_WitnessTable: case kIROp_StructType: + case kIROp_SPIRVAsm: dumpIRParentInst(context, inst); return; diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp index d29cc4485..29661a415 100644 --- a/source/slang/slang-language-server-ast-lookup.cpp +++ b/source/slang/slang-language-server-ast-lookup.cpp @@ -405,6 +405,18 @@ public: { return dispatchIfNotNull(expr->originalExpr); } + bool visitSPIRVAsmExpr(SPIRVAsmExpr* expr) + { + for(const auto& i : expr->insts) + { + if(dispatchIfNotNull(i.opcode.expr)) + return true; + for(const auto& o : i.operands) + if(dispatchIfNotNull(o.expr)) + return true; + } + return false; + } bool visitModifiedTypeExpr(ModifiedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); } bool visitFuncTypeExpr(FuncTypeExpr* expr) { diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp index 497981a94..854485185 100644 --- a/source/slang/slang-lower-to-ir.cpp +++ b/source/slang/slang-lower-to-ir.cpp @@ -3244,6 +3244,108 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> UNREACHABLE_RETURN(LoweredValInfo()); } + LoweredValInfo visitSPIRVAsmExpr(SPIRVAsmExpr* expr) + { + // Although the surface syntax can have an empty ASM block, the IR asm + // block must have at least one inst + if(!expr->insts.getCount()) + return LoweredValInfo{}; + + auto builder = context->irBuilder; + + const auto type = lowerType(context, expr->type); + const auto spirvAsmInst = builder->emitSPIRVAsm(type); + + const auto lowerOperand = [&](const SPIRVAsmOperand& operand) -> IRSPIRVAsmOperand* { + switch(operand.flavor) + { + case SPIRVAsmOperand::Literal: + { + if(operand.token.type == TokenType::IntegerLiteral) + { + const auto v = getIntegerLiteralValue(operand.token); + // TODO: we should sign-extend these where appropriate, + // difficult because it requires information on usage... + return builder->emitSPIRVAsmOperandLiteral( + builder->getIntValue(builder->getUIntType(), v)); + } + else if(operand.token.type == TokenType::StringLiteral) + { + const auto v = getStringLiteralTokenValue(operand.token); + return builder->emitSPIRVAsmOperandLiteral( + builder->getStringValue(v.getUnownedSlice())); + } + SLANG_UNREACHABLE("Unhandled literal type in visitSPIRVAsmExpr"); + } + case SPIRVAsmOperand::Id: + { + const auto id = operand.token.getContent(); + return builder->emitSPIRVAsmOperandId( + builder->getStringValue(id)); + } + case SPIRVAsmOperand::NamedValue: + { + const auto id = operand.token.getContent(); + return builder->emitSPIRVAsmOperandEnum( + builder->getStringValue(id)); + } + case SPIRVAsmOperand::SlangValue: + { + IRInst* i; + { + IRBuilderInsertLocScope insertScope(builder); + builder->setInsertBefore(spirvAsmInst); + i = getSimpleVal(context, lowerRValueExpr(context, operand.expr)); + } + return builder->emitSPIRVAsmOperandInst(i); + } + case SPIRVAsmOperand::SlangValueAddr: + { + IRInst* i; + { + IRBuilderInsertLocScope insertScope(builder); + builder->setInsertBefore(spirvAsmInst); + const auto addr = tryGetAddress( + context, + lowerLValueExpr(context, operand.expr), + TryGetAddressMode::Default + ); + if(addr.flavor == LoweredValInfo::Flavor::Ptr) + i = addr.val; + else + { + context->getSink()->diagnose(operand.expr, Diagnostics::noSuchAddress); + return nullptr; + } + } + return builder->emitSPIRVAsmOperandInst(i); + } + case SPIRVAsmOperand::SlangType: + { + IRInst* i; + { + IRBuilderInsertLocScope insertScope(builder); + builder->setInsertBefore(spirvAsmInst); + i = lowerType(context, operand.type.type); + } + return builder->emitSPIRVAsmOperandInst(i); + } + } + SLANG_UNREACHABLE("Unhandled case in visitSPIRVAsmExpr"); + }; + IRBuilderInsertLocScope insertScope(builder); + builder->setInsertInto(spirvAsmInst); + for(const auto& inst : expr->insts) + { + const auto opcode = lowerOperand(inst.opcode); + List<IRInst*> operands; + for(const auto& operand : inst.operands) + operands.add(lowerOperand(operand)); + builder->emitSPIRVAsmInst(opcode, operands); + } + return LoweredValInfo::simple(spirvAsmInst); + } + LoweredValInfo visitIndexExpr(IndexExpr* expr) { auto type = lowerType(context, expr->type); diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp index fabf1ffb4..442ddbce6 100644 --- a/source/slang/slang-parser.cpp +++ b/source/slang/slang-parser.cpp @@ -2,6 +2,7 @@ #include <assert.h> #include <float.h> +#include <optional> #include "slang-compiler.h" #include "slang-lookup.h" @@ -6147,6 +6148,120 @@ namespace Slang } } + static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser) + { + const auto slangIdentOperand = [&](auto flavor){ + const auto tok = parser->ReadToken(TokenType::Identifier); + + VarExpr* varExpr = parser->astBuilder->create<VarExpr>(); + varExpr->scope = parser->currentScope; + varExpr->loc = tok.getLoc(); + varExpr->name = tok.getName(); + return SPIRVAsmOperand{flavor, tok, varExpr}; + }; + + // A regular identifier + if(parser->LookAheadToken(TokenType::Identifier)) + { + return SPIRVAsmOperand{SPIRVAsmOperand::NamedValue, parser->ReadToken()}; + } + // A literal integer or string + else if(parser->LookAheadToken(TokenType::IntegerLiteral) + || parser->LookAheadToken(TokenType::StringLiteral)) + { + return SPIRVAsmOperand{SPIRVAsmOperand::Literal, parser->ReadToken()}; + } + // A %foo id + else if(AdvanceIf(parser, TokenType::OpMod)) + { + if(parser->LookAheadToken(TokenType::IntegerLiteral) + || parser->LookAheadToken(TokenType::Identifier)) + { + return SPIRVAsmOperand{SPIRVAsmOperand::Id, parser->ReadToken()}; + } + } + // A &foo variable reference (for the address of foo) + else if(AdvanceIf(parser, TokenType::OpBitAnd)) + { + return slangIdentOperand(SPIRVAsmOperand::SlangValueAddr); + } + // A $foo variable + else if(AdvanceIf(parser, TokenType::Dollar)) + { + return slangIdentOperand(SPIRVAsmOperand::SlangValue); + } + // A $$foo type + else if(AdvanceIf(parser, TokenType::DollarDollar)) + { + return slangIdentOperand(SPIRVAsmOperand::SlangType); + } + + Unexpected(parser); + return std::nullopt; + } + + static std::optional<SPIRVAsmInst> parseSPIRVAsmInst(Parser* parser) + { + SPIRVAsmInst ret; + + const auto resultOrOpcode = parseSPIRVAsmOperand(parser); + if(!resultOrOpcode) + return std::nullopt; + + // We can enable this when we have a way of determining the index of + // the result id operand to each instruction, otherwise we don't know + // at which position in the operand list to insert this. +#if 0 + if(AdvanceIf(parser, TokenType::OpEql)) + { + const auto opcode = parseSPIRVAsmOperand(parser); + if(!opcode) + return std::nullopt; + ret.opcode = *opcode; + ret.operands.insert(???, *resultOrOpcode); + } + else +#endif + { + ret.opcode = *resultOrOpcode; + } + + // TODO: diagnose wrong opcode flavor here + + while(!(parser->LookAheadToken(TokenType::RBrace) + || parser->LookAheadToken(TokenType::Semicolon))) + { + if(const auto operand = parseSPIRVAsmOperand(parser)) + ret.operands.add(*operand); + else + return std::nullopt; + } + + return ret; + } + + static Expr* parseSPIRVAsmExpr(Parser* parser) + { + SPIRVAsmExpr* asmExpr = parser->astBuilder->create<SPIRVAsmExpr>(); + + parser->ReadToken(TokenType::LBrace); + while(!parser->tokenReader.isAtEnd()) + { + if(parser->LookAheadToken(TokenType::RBrace)) + break; + if(const auto inst = parseSPIRVAsmInst(parser)) + asmExpr->insts.add(*inst); + else + return nullptr; + if(parser->LookAheadToken(TokenType::RBrace)) + break; + parser->ReadToken(TokenType::Semicolon); + } + parser->ReadToken(TokenType::RBrace); + + return asmExpr; + } + static Expr* parsePrefixExpr(Parser* parser) { auto tokenType = peekTokenType(parser); @@ -6179,7 +6294,10 @@ namespace Slang } return newExpr; } - + else if (AdvanceIf(parser, "spirv_asm")) + { + return parseSPIRVAsmExpr(parser); + } return parsePostfixExpr(parser); } diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h index 5ccf9ea54..f5d636b01 100644 --- a/source/slang/slang-serialize-ast-type-info.h +++ b/source/slang/slang-serialize-ast-type-info.h @@ -237,6 +237,9 @@ struct SerialTypeInfo<RequirementWitness::Flavor> : public SerialConvertTypeInfo // RequirementWitness SLANG_VALUE_TYPE_INFO(RequirementWitness) +// SPIRVAsm +SLANG_VALUE_TYPE_INFO(SPIRVAsmOperand) +SLANG_VALUE_TYPE_INFO(SPIRVAsmInst) } // namespace Slang |
