summaryrefslogtreecommitdiff
path: root/source
diff options
context:
space:
mode:
authorEllie Hermaszewska <ellieh@nvidia.com>2023-08-26 01:42:34 +0800
committerGitHub <noreply@github.com>2023-08-25 10:42:34 -0700
commitef4c9f1f1c297f1a33be95795a7a7561e0cc3bde (patch)
tree9ea81689432040905772aeec447adad88f212e01 /source
parent036abc85ba1db9c8c06289f0a0492e9a95a228b9 (diff)
Initial version of spirv_asm block (#3151)
* Initial version of spirv_asm block * Correct indentation of parent instruction dumping * neater dumping for spirv_asm instructions * Add $$ DollarDollar token * Allow passing addresses to spirv_asm blocks * spirv OpUndef * String literals in spirv asm * OpName for spirv_asm ids * Correct failure in lower spirv_asm * correct position for spirv_asm idents * comment correct * several more tests for spirv_asm blocks * Fill out some unimplemented functions for spirv_asm expressions --------- Co-authored-by: Yong He <yonghe@outlook.com>
Diffstat (limited to 'source')
-rw-r--r--source/compiler-core/slang-lexer.cpp11
-rw-r--r--source/compiler-core/slang-token-defs.h1
-rw-r--r--source/slang/slang-ast-dump.cpp49
-rw-r--r--source/slang/slang-ast-expr.h37
-rw-r--r--source/slang/slang-ast-iterator.h11
-rw-r--r--source/slang/slang-check-expr.cpp60
-rw-r--r--source/slang/slang-check-impl.h2
-rw-r--r--source/slang/slang-diagnostic-defs.h7
-rw-r--r--source/slang/slang-emit-spirv.cpp123
-rw-r--r--source/slang/slang-ir-inst-defs.h20
-rw-r--r--source/slang/slang-ir-insts.h50
-rw-r--r--source/slang/slang-ir.cpp103
-rw-r--r--source/slang/slang-language-server-ast-lookup.cpp12
-rw-r--r--source/slang/slang-lower-to-ir.cpp102
-rw-r--r--source/slang/slang-parser.cpp120
-rw-r--r--source/slang/slang-serialize-ast-type-info.h3
16 files changed, 705 insertions, 6 deletions
diff --git a/source/compiler-core/slang-lexer.cpp b/source/compiler-core/slang-lexer.cpp
index 24cd3034b..5954dc668 100644
--- a/source/compiler-core/slang-lexer.cpp
+++ b/source/compiler-core/slang-lexer.cpp
@@ -1314,7 +1314,16 @@ namespace Slang
case '?': _advance(lexer); return TokenType::QuestionMark;
case '@': _advance(lexer); return TokenType::At;
- case '$': _advance(lexer); return TokenType::Dollar;
+ case '$':
+ {
+ _advance(lexer);
+ if(_peek(lexer) == '$')
+ {
+ _advance(lexer);
+ return TokenType::DollarDollar;
+ }
+ return TokenType::Dollar;
+ }
}
diff --git a/source/compiler-core/slang-token-defs.h b/source/compiler-core/slang-token-defs.h
index 45b4912e7..2a66359fe 100644
--- a/source/compiler-core/slang-token-defs.h
+++ b/source/compiler-core/slang-token-defs.h
@@ -85,6 +85,7 @@ PUNCTUATION(Colon, ":")
PUNCTUATION(RightArrow, "->")
PUNCTUATION(At, "@")
PUNCTUATION(Dollar, "$")
+PUNCTUATION(DollarDollar, "$$")
PUNCTUATION(Pound, "#")
PUNCTUATION(PoundPound, "##")
diff --git a/source/slang/slang-ast-dump.cpp b/source/slang/slang-ast-dump.cpp
index d016d1c15..7dba55c52 100644
--- a/source/slang/slang-ast-dump.cpp
+++ b/source/slang/slang-ast-dump.cpp
@@ -644,6 +644,55 @@ struct ASTDumpContext
m_writer->emit(")");
}
+ void dump(const SPIRVAsmOperand& operand)
+ {
+ switch(operand.flavor)
+ {
+ case SPIRVAsmOperand::Id:
+ m_writer->emit("%");
+ break;
+ case SPIRVAsmOperand::Literal:
+ case SPIRVAsmOperand::NamedValue:
+ break;
+ case SPIRVAsmOperand::SlangValue:
+ m_writer->emit("$");
+ break;
+ case SPIRVAsmOperand::SlangValueAddr:
+ m_writer->emit("&");
+ break;
+ case SPIRVAsmOperand::SlangType:
+ m_writer->emit("$$");
+ break;
+ default:
+ SLANG_UNREACHABLE("Unhandled case in ast dump for SPIRVAsmOperand");
+ }
+ if(operand.expr)
+ dump(operand.expr);
+ else
+ dump(operand.token);
+ }
+
+ void dump(const SPIRVAsmInst& inst)
+ {
+ dump(inst.opcode);
+ for(const auto& o : inst.operands)
+ dump(o);
+ }
+
+ void dump(const SPIRVAsmExpr& expr)
+ {
+ m_writer->emit("spirv_asm\n");
+ m_writer->emit("{\n");
+ m_writer->indent();
+ for(const auto& i : expr.insts)
+ {
+ dump(i);
+ m_writer->emit(";\n");
+ }
+ m_writer->dedent();
+ m_writer->emit("}");
+ }
+
void dumpObjectFull(NodeBase* node);
ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle):
diff --git a/source/slang/slang-ast-expr.h b/source/slang/slang-ast-expr.h
index 28ce2e4d1..36d304a1a 100644
--- a/source/slang/slang-ast-expr.h
+++ b/source/slang/slang-ast-expr.h
@@ -631,4 +631,41 @@ public:
List<Val*> knownGenericArgs;
};
+class SPIRVAsmOperand
+{
+ SLANG_VALUE_CLASS(SPIRVAsmOperand);
+
+public:
+ enum Flavor
+ {
+ Literal, // No prefix
+ Id, // Prefixed with %
+ NamedValue, // An identifier
+ SlangValue,
+ SlangValueAddr,
+ SlangType,
+ };
+ Flavor flavor;
+ Token token;
+ Expr* expr = nullptr;
+ TypeExp type = TypeExp();
+};
+
+class SPIRVAsmInst
+{
+ SLANG_VALUE_CLASS(SPIRVAsmInst);
+
+public:
+ SPIRVAsmOperand opcode;
+ List<SPIRVAsmOperand> operands;
+};
+
+class SPIRVAsmExpr : public Expr
+{
+ SLANG_AST_CLASS(SPIRVAsmExpr);
+
+public:
+ List<SPIRVAsmInst> insts;
+};
+
} // namespace Slang
diff --git a/source/slang/slang-ast-iterator.h b/source/slang/slang-ast-iterator.h
index fb3d50b4f..76effa608 100644
--- a/source/slang/slang-ast-iterator.h
+++ b/source/slang/slang-ast-iterator.h
@@ -282,6 +282,17 @@ struct ASTIterator
{
dispatchIfNotNull(expr->innerExpr);
}
+
+ void visitSPIRVAsmExpr(SPIRVAsmExpr* expr)
+ {
+ iterator->maybeDispatchCallback(expr);
+ for(const auto& i : expr->insts)
+ {
+ dispatchIfNotNull(i.opcode.expr);
+ for(const auto& o : i.operands)
+ dispatchIfNotNull(o.expr);
+ }
+ }
};
struct ASTIteratorStmtVisitor : public StmtVisitor<ASTIteratorStmtVisitor>
diff --git a/source/slang/slang-check-expr.cpp b/source/slang/slang-check-expr.cpp
index 3d2f81edb..5266f02f6 100644
--- a/source/slang/slang-check-expr.cpp
+++ b/source/slang/slang-check-expr.cpp
@@ -3895,4 +3895,64 @@ namespace Slang
return expr;
}
+
+ Expr* SemanticsExprVisitor::visitSPIRVAsmExpr(SPIRVAsmExpr* expr)
+ {
+ // We will iterate over all the operands in all the insts and check
+ // them
+ for(auto& inst : expr->insts)
+ {
+ const bool isLast = &inst == &expr->insts.getLast();
+ for(auto& operand : inst.operands)
+ {
+ if(operand.flavor == SPIRVAsmOperand::SlangType)
+ {
+ // This is a $$type operand, fill in the TypeExp member of the operand
+ TypeExp& typeExpr = operand.type;
+ typeExpr.exp = operand.expr;
+ typeExpr = CheckProperType(typeExpr);
+ operand.expr = typeExpr.exp;
+ }
+ else if(operand.flavor == SPIRVAsmOperand::SlangValue
+ || operand.flavor == SPIRVAsmOperand::SlangValueAddr)
+ {
+ // This is a $expr operand, check the expr
+ operand.expr = dispatch(operand.expr);
+ }
+ else if(operand.flavor == SPIRVAsmOperand::NamedValue
+ && operand.token.getContent() == "result")
+ {
+ // This is the <result-id> marker, check that it only
+ // appears in the last instruction.
+
+ // TODO: We could consider relaxing this, because SPIR-V
+ // does have forward references for decorations and such
+ if (!isLast)
+ {
+ getSink()->diagnose(operand.token, Diagnostics::misplacedResultIdMarker);
+ getSink()->diagnoseWithoutSourceView(expr, Diagnostics::considerOpCopyObject);
+ }
+ }
+ }
+ }
+
+ // Assign the type of this expression from the type of the last
+ // instruction, otherwise void
+ if(expr->insts.getCount())
+ {
+ // TODO: we trust that this is correct, but could should verify
+ const auto lastOperands = expr->insts.getLast().operands;
+ if(lastOperands.getCount() >= 2
+ && lastOperands[0].flavor == SPIRVAsmOperand::SlangType
+ && lastOperands[1].flavor == SPIRVAsmOperand::NamedValue
+ && lastOperands[1].token.getContent() == "result")
+ {
+ expr->type = lastOperands[0].type.type;
+ }
+ }
+ if(!expr->type)
+ expr->type = m_astBuilder->getVoidType();
+
+ return expr;
+ }
}
diff --git a/source/slang/slang-check-impl.h b/source/slang/slang-check-impl.h
index 8e37e7967..1e3bde4de 100644
--- a/source/slang/slang-check-impl.h
+++ b/source/slang/slang-check-impl.h
@@ -2424,6 +2424,8 @@ namespace Slang
Expr* visitGetArrayLengthExpr(GetArrayLengthExpr* expr);
+ Expr* visitSPIRVAsmExpr(SPIRVAsmExpr*);
+
/// Perform semantic checking on a `modifier` that is being applied to the given `type`
Val* checkTypeModifier(Modifier* modifier, Type* type);
diff --git a/source/slang/slang-diagnostic-defs.h b/source/slang/slang-diagnostic-defs.h
index 0243543e9..ac6e9a932 100644
--- a/source/slang/slang-diagnostic-defs.h
+++ b/source/slang/slang-diagnostic-defs.h
@@ -238,9 +238,14 @@ DIAGNOSTIC(20015, Error, unknownSPIRVCapability, "unknown SPIR-V capability '$0'
DIAGNOSTIC(20101, Warning, unintendedEmptyStatement, "potentially unintended empty statement at this location; use {} instead.")
-// 29xxx - Snippet parsing
+// 29xxx - Snippet parsing and inline asm
DIAGNOSTIC(29000, Error, snippetParsingFailed, "unable to parse target intrinsic snippet: $0")
+DIAGNOSTIC(29100, Error, unrecognizedSPIRVOpcode, "unrecognized spirv opcode: $0")
+DIAGNOSTIC(29101, Error, misplacedResultIdMarker, "the result-id marker must only be used in the last instruction of a spriv_asm expression")
+DIAGNOSTIC(29102, Note, considerOpCopyObject, "consider adding an OpCopyObject instruction to the end of the spirv_asm expression")
+DIAGNOSTIC(29103, Note, noSuchAddress, "unable to take the address of this address-of asm operand")
+
//
// 3xxxx - Semantic analysis
diff --git a/source/slang/slang-emit-spirv.cpp b/source/slang/slang-emit-spirv.cpp
index 29114424d..0022bdd85 100644
--- a/source/slang/slang-emit-spirv.cpp
+++ b/source/slang/slang-emit-spirv.cpp
@@ -10,6 +10,7 @@
#include "slang-ir-spirv-snippet.h"
#include "slang-ir-spirv-legalize.h"
#include "slang-spirv-val.h"
+#include "slang-lookup-spirv.h"
#include "spirv/unified1/spirv.h"
#include "../core/slang-memory-arena.h"
#include <type_traits>
@@ -577,13 +578,18 @@ struct SPIRVEmitContext
//
// We will allocate <id>s on emand as they are needed.
+ SpvWord freshID()
+ {
+ return m_nextID++;
+ }
+
/// Get the <id> for `inst`, or assign one if it doesn't have one yet
SpvWord getID(SpvInst* inst)
{
auto id = inst->id;
if( !id )
{
- id = m_nextID++;
+ id = freshID();
inst->id = id;
}
return id;
@@ -2019,7 +2025,10 @@ struct SPIRVEmitContext
return emitDebugLine(parent, as<IRDebugLine>(inst));
case kIROp_GetStringHash:
return emitGetStringHash(inst);
-
+ case kIROp_undefined:
+ return emitOpUndef(parent, inst, inst->getDataType());
+ case kIROp_SPIRVAsm:
+ return emitSPIRVAsm(parent, as<IRSPIRVAsm>(inst));
}
}
@@ -3789,6 +3798,116 @@ struct SPIRVEmitContext
debugLine->getColEnd());
}
+ SpvInst* emitSPIRVAsm(SpvInstParent* parent, IRSPIRVAsm* inst)
+ {
+ SpvInst* last = nullptr;
+
+ // This keeps track of the named IDs used in the asm block
+ Dictionary<UnownedStringSlice, SpvWord> idMap;
+
+ for(const auto spvInst : inst->getInsts())
+ {
+ const bool isLast = spvInst == inst->getLastChild();
+ const auto opcodeString = spvInst->getOpcodeString();
+ SpvOp opcode;
+ const bool foundOpCode = lookupSpvOp(opcodeString, opcode)
+ || lookupSpvOp((String("Op") + opcodeString).getUnownedSlice(), opcode);
+ if(!foundOpCode)
+ {
+ m_sink->diagnose(
+ spvInst->getOpcode(),
+ Diagnostics::unrecognizedSPIRVOpcode,
+ opcodeString
+ );
+ return nullptr;
+ }
+
+ const auto parentForOpCode = [this](SpvOp opcode, SpvInstParent* defaultParent){
+ return
+ opcode == SpvOpConstant ? getSection(SpvLogicalSectionID::ConstantsAndTypes)
+ : opcode == SpvOpName ? getSection(SpvLogicalSectionID::DebugNames)
+ : defaultParent;
+ };
+
+ last = emitInstCustomOperandFunc(
+ parentForOpCode(opcode, parent),
+ // We want the "result instruction" to refer to the top level
+ // block which assumes its value, the others are free to refer
+ // to whatever, so just use the internal spv inst rep
+ // TODO: This is not correct, because the instruction which is
+ // assigned to result is not necessarily the last instruction
+ isLast ? as<IRInst>(inst) : spvInst,
+ opcode,
+ [&](){
+ for(const auto operand : spvInst->getSPIRVOperands())
+ {
+ switch(operand->getOp())
+ {
+ case kIROp_SPIRVAsmOperandLiteral:
+ {
+ const auto v = as<IRConstant>(operand->getValue());
+ SLANG_ASSERT(v);
+ switch(v->getOp())
+ {
+ case kIROp_StringLit:
+ emitOperand(SpvLiteralBits::fromUnownedStringSlice(v->getStringSlice()));
+ break;
+ case kIROp_IntLit:
+ {
+ // TODO: range checking
+ const auto i = cast<IRIntLit>(v)->getValue();
+ emitOperand(SpvLiteralInteger::from32(uint32_t(i)));
+ break;
+ }
+ default:
+ SLANG_UNREACHABLE("Unhandled case in emitSPIRVAsm");
+ }
+ break;
+ }
+ case kIROp_SPIRVAsmOperandInst:
+ {
+ const auto i = operand->getValue();
+ emitOperand(ensureInst(i));
+ break;
+ }
+ case kIROp_SPIRVAsmOperandEnum:
+ {
+ const auto s = cast<IRStringLit>(operand->getValue())->getStringSlice();
+ if(s == "result")
+ {
+ SLANG_ASSERT(isLast);
+ emitOperand(kResultID);
+ }
+ else
+ SLANG_UNIMPLEMENTED_X("lookup enum operands in spirv_asm");
+ break;
+ }
+ case kIROp_SPIRVAsmOperandId:
+ {
+ const auto idName = cast<IRStringLit>(operand->getValue())->getStringSlice();
+ SpvWord id;
+ if(!idMap.tryGetValue(idName, id))
+ {
+ id = freshID();
+ idMap.set(idName, id);
+ }
+ emitOperand(id);
+ break;
+ }
+ default:
+ SLANG_UNREACHABLE("Unhandled case in emitSPIRVAsm");
+ }
+ }
+ }
+ );
+ }
+
+ for(const auto& [name, id] : idMap)
+ emitOpName(getSection(SpvLogicalSectionID::DebugNames), nullptr, id, name);
+
+ return last;
+ }
+
OrderedHashSet<SpvCapability> m_capabilities;
void requireSPIRVCapability(SpvCapability capability)
diff --git a/source/slang/slang-ir-inst-defs.h b/source/slang/slang-ir-inst-defs.h
index d264dfd06..980770be5 100644
--- a/source/slang/slang-ir-inst-defs.h
+++ b/source/slang/slang-ir-inst-defs.h
@@ -1053,6 +1053,26 @@ INST(DifferentiableTypeDictionaryItem, DifferentiableTypeDictionaryItem, 0, 0)
INST(DebugSource, DebugSource, 2, HOISTABLE)
INST(DebugLine, DebugLine, 5, 0)
+/* Inline assembly */
+INST(SPIRVAsm, SPIRVAsm, 0, PARENT)
+INST(SPIRVAsmInst, SPIRVAsmInst, 1, 0)
+ // These instruction serve to inform the backend precisely how to emit each
+ // instruction, consider the difference between emitting a literal integer
+ // and a reference to a literal integer instruction
+ //
+ // A literal string or 32-bit integer to be passed as operands
+ INST(SPIRVAsmOperandLiteral, SPIRVAsmOperandLiteral, 1, 0)
+ // A reference to a slang IRInst, either a value or a type
+ INST(SPIRVAsmOperandInst, SPIRVAsmOperandInst, 1, 0)
+ // A named enumerator, the value of which is determined in the backend
+ // It can also have the value "result", indicating that the result-id of
+ // the asm block should be used
+ INST(SPIRVAsmOperandEnum, SPIRVAsmOperandEnum, 1, 0)
+ // A string which is given a unique ID in the backend, used to refer to
+ // results of other instrucions in the same asm block
+ INST(SPIRVAsmOperandId, SPIRVAsmOperandId, 1, 0)
+INST_RANGE(SPIRVAsmOperand, SPIRVAsmOperandLiteral, SPIRVAsmOperandId)
+
#undef PARENT
#undef USE_OTHER
#undef INST_RANGE
diff --git a/source/slang/slang-ir-insts.h b/source/slang/slang-ir-insts.h
index 5639d90dd..e12306c54 100644
--- a/source/slang/slang-ir-insts.h
+++ b/source/slang/slang-ir-insts.h
@@ -2883,6 +2883,49 @@ struct IRDebugLine : IRInst
IRInst* getColEnd() { return getOperand(4); }
};
+struct IRSPIRVAsmOperand : IRInst
+{
+ IR_PARENT_ISA(SPIRVAsmOperand);
+ IRInst* getValue()
+ {
+ return getOperand(0);
+ }
+};
+
+struct IRSPIRVAsmInst : IRInst
+{
+ IR_LEAF_ISA(SPIRVAsmInst);
+
+ IRSPIRVAsmOperand* getOpcode()
+ {
+ // TODO: This only supports known opcodes at the moment, eventually we'll want
+ // another child of IRSPIRVAsm which just stores raw words
+ const auto opcodeOperand = cast<IRSPIRVAsmOperand>(getOperand(0));
+ SLANG_ASSERT(opcodeOperand->getOp() == kIROp_SPIRVAsmOperandEnum);
+ return opcodeOperand;
+ }
+
+ UnownedStringSlice getOpcodeString()
+ {
+ const auto opcodeOperand = getOpcode();
+ const auto opcodeStringLit = cast<IRStringLit>(opcodeOperand->getValue());
+ return opcodeStringLit->getStringSlice();
+ }
+
+ IROperandList<IRSPIRVAsmOperand> getSPIRVOperands()
+ {
+ return IROperandList<IRSPIRVAsmOperand>(getOperands() + 1, getOperands() + getOperandCount());
+ }
+};
+
+struct IRSPIRVAsm : IRInst
+{
+ IR_LEAF_ISA(SPIRVAsm);
+ IRFilteredInstList<IRSPIRVAsmInst> getInsts()
+ {
+ return IRFilteredInstList<IRSPIRVAsmInst>(getFirstChild(), getLastChild());
+ }
+};
struct IRBuilderSourceLocRAII;
@@ -3870,6 +3913,13 @@ public:
IRInst* emitShr(IRType* type, IRInst* op0, IRInst* op1);
IRInst* emitShl(IRType* type, IRInst* op0, IRInst* op1);
+ IRSPIRVAsmOperand* emitSPIRVAsmOperandLiteral(IRInst* literal);
+ IRSPIRVAsmOperand* emitSPIRVAsmOperandInst(IRInst* inst);
+ IRSPIRVAsmOperand* emitSPIRVAsmOperandId(IRInst* inst);
+ IRSPIRVAsmOperand* emitSPIRVAsmOperandEnum(IRInst* inst);
+ IRSPIRVAsmInst* emitSPIRVAsmInst(IRInst* opcode, List<IRInst*> operands);
+ IRSPIRVAsm* emitSPIRVAsm(IRType* type);
+
//
// Decorations
//
diff --git a/source/slang/slang-ir.cpp b/source/slang/slang-ir.cpp
index 181970632..91a21754d 100644
--- a/source/slang/slang-ir.cpp
+++ b/source/slang/slang-ir.cpp
@@ -5664,6 +5664,84 @@ namespace Slang
return inst;
}
+ IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandLiteral(IRInst* literal)
+ {
+ SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent()));
+ const auto i = createInst<IRSPIRVAsmOperand>(
+ this,
+ kIROp_SPIRVAsmOperandLiteral,
+ literal->getFullType(),
+ literal
+ );
+ addInst(i);
+ return i;
+ }
+
+ IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandInst(IRInst* inst)
+ {
+ SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent()));
+ const auto i = createInst<IRSPIRVAsmOperand>(
+ this,
+ kIROp_SPIRVAsmOperandInst,
+ inst->getFullType(),
+ inst
+ );
+ addInst(i);
+ return i;
+ }
+
+ IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandId(IRInst* inst)
+ {
+ SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent()));
+ const auto i = createInst<IRSPIRVAsmOperand>(
+ this,
+ kIROp_SPIRVAsmOperandId,
+ inst->getFullType(),
+ inst
+ );
+ addInst(i);
+ return i;
+ }
+
+ IRSPIRVAsmOperand* IRBuilder::emitSPIRVAsmOperandEnum(IRInst* inst)
+ {
+ SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent()));
+ const auto i = createInst<IRSPIRVAsmOperand>(
+ this,
+ kIROp_SPIRVAsmOperandEnum,
+ inst->getFullType(),
+ inst
+ );
+ addInst(i);
+ return i;
+ }
+
+ IRSPIRVAsmInst* IRBuilder::emitSPIRVAsmInst(IRInst* opcode, List<IRInst*> operands)
+ {
+ SLANG_ASSERT(as<IRSPIRVAsm>(m_insertLoc.getParent()));
+ operands.insert(0, opcode);
+ const auto i = createInst<IRSPIRVAsmInst>(
+ this,
+ kIROp_SPIRVAsmInst,
+ getVoidType(),
+ operands.getCount(),
+ operands.getBuffer()
+ );
+ addInst(i);
+ return i;
+ }
+
+ IRSPIRVAsm* IRBuilder::emitSPIRVAsm(IRType* type)
+ {
+ const auto asmInst = createInst<IRSPIRVAsm>(
+ this,
+ kIROp_SPIRVAsm,
+ type
+ );
+ addInst(asmInst);
+ return asmInst;
+ }
+
//
// Decorations
//
@@ -6158,6 +6236,9 @@ namespace Slang
if(as<IRType>(inst))
return true;
+ if(as<IRSPIRVAsmOperand>(inst))
+ return true;
+
return false;
}
@@ -6370,7 +6451,6 @@ namespace Slang
{
auto opInfo = getIROpInfo(inst->getOp());
- dumpIndent(context);
dump(context, opInfo.name);
dump(context, " ");
dumpID(context, inst);
@@ -6398,6 +6478,7 @@ namespace Slang
}
context->indent--;
+ dumpIndent(context);
dump(context, "}\n");
}
@@ -6468,6 +6549,25 @@ namespace Slang
}
}
+ // Special case the SPIR-V asm operands as the distinction here is
+ // clear anyway to the user
+ switch(op)
+ {
+ case kIROp_SPIRVAsmOperandEnum:
+ dumpInstExpr(context, inst->getOperand(0));
+ return;
+ case kIROp_SPIRVAsmOperandLiteral:
+ dumpInstExpr(context, inst->getOperand(0));
+ return;
+ case kIROp_SPIRVAsmOperandInst:
+ dumpInstExpr(context, inst->getOperand(0));
+ return;
+ case kIROp_SPIRVAsmOperandId:
+ dump(context, "%");
+ dumpInstExpr(context, inst->getOperand(0));
+ return;
+ }
+
dump(context, opInfo.name);
dumpInstOperandList(context, inst);
}
@@ -6501,6 +6601,7 @@ namespace Slang
case kIROp_WitnessTable:
case kIROp_StructType:
+ case kIROp_SPIRVAsm:
dumpIRParentInst(context, inst);
return;
diff --git a/source/slang/slang-language-server-ast-lookup.cpp b/source/slang/slang-language-server-ast-lookup.cpp
index d29cc4485..29661a415 100644
--- a/source/slang/slang-language-server-ast-lookup.cpp
+++ b/source/slang/slang-language-server-ast-lookup.cpp
@@ -405,6 +405,18 @@ public:
{
return dispatchIfNotNull(expr->originalExpr);
}
+ bool visitSPIRVAsmExpr(SPIRVAsmExpr* expr)
+ {
+ for(const auto& i : expr->insts)
+ {
+ if(dispatchIfNotNull(i.opcode.expr))
+ return true;
+ for(const auto& o : i.operands)
+ if(dispatchIfNotNull(o.expr))
+ return true;
+ }
+ return false;
+ }
bool visitModifiedTypeExpr(ModifiedTypeExpr* expr) { return dispatchIfNotNull(expr->base.exp); }
bool visitFuncTypeExpr(FuncTypeExpr* expr)
{
diff --git a/source/slang/slang-lower-to-ir.cpp b/source/slang/slang-lower-to-ir.cpp
index 497981a94..854485185 100644
--- a/source/slang/slang-lower-to-ir.cpp
+++ b/source/slang/slang-lower-to-ir.cpp
@@ -3244,6 +3244,108 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo>
UNREACHABLE_RETURN(LoweredValInfo());
}
+ LoweredValInfo visitSPIRVAsmExpr(SPIRVAsmExpr* expr)
+ {
+ // Although the surface syntax can have an empty ASM block, the IR asm
+ // block must have at least one inst
+ if(!expr->insts.getCount())
+ return LoweredValInfo{};
+
+ auto builder = context->irBuilder;
+
+ const auto type = lowerType(context, expr->type);
+ const auto spirvAsmInst = builder->emitSPIRVAsm(type);
+
+ const auto lowerOperand = [&](const SPIRVAsmOperand& operand) -> IRSPIRVAsmOperand* {
+ switch(operand.flavor)
+ {
+ case SPIRVAsmOperand::Literal:
+ {
+ if(operand.token.type == TokenType::IntegerLiteral)
+ {
+ const auto v = getIntegerLiteralValue(operand.token);
+ // TODO: we should sign-extend these where appropriate,
+ // difficult because it requires information on usage...
+ return builder->emitSPIRVAsmOperandLiteral(
+ builder->getIntValue(builder->getUIntType(), v));
+ }
+ else if(operand.token.type == TokenType::StringLiteral)
+ {
+ const auto v = getStringLiteralTokenValue(operand.token);
+ return builder->emitSPIRVAsmOperandLiteral(
+ builder->getStringValue(v.getUnownedSlice()));
+ }
+ SLANG_UNREACHABLE("Unhandled literal type in visitSPIRVAsmExpr");
+ }
+ case SPIRVAsmOperand::Id:
+ {
+ const auto id = operand.token.getContent();
+ return builder->emitSPIRVAsmOperandId(
+ builder->getStringValue(id));
+ }
+ case SPIRVAsmOperand::NamedValue:
+ {
+ const auto id = operand.token.getContent();
+ return builder->emitSPIRVAsmOperandEnum(
+ builder->getStringValue(id));
+ }
+ case SPIRVAsmOperand::SlangValue:
+ {
+ IRInst* i;
+ {
+ IRBuilderInsertLocScope insertScope(builder);
+ builder->setInsertBefore(spirvAsmInst);
+ i = getSimpleVal(context, lowerRValueExpr(context, operand.expr));
+ }
+ return builder->emitSPIRVAsmOperandInst(i);
+ }
+ case SPIRVAsmOperand::SlangValueAddr:
+ {
+ IRInst* i;
+ {
+ IRBuilderInsertLocScope insertScope(builder);
+ builder->setInsertBefore(spirvAsmInst);
+ const auto addr = tryGetAddress(
+ context,
+ lowerLValueExpr(context, operand.expr),
+ TryGetAddressMode::Default
+ );
+ if(addr.flavor == LoweredValInfo::Flavor::Ptr)
+ i = addr.val;
+ else
+ {
+ context->getSink()->diagnose(operand.expr, Diagnostics::noSuchAddress);
+ return nullptr;
+ }
+ }
+ return builder->emitSPIRVAsmOperandInst(i);
+ }
+ case SPIRVAsmOperand::SlangType:
+ {
+ IRInst* i;
+ {
+ IRBuilderInsertLocScope insertScope(builder);
+ builder->setInsertBefore(spirvAsmInst);
+ i = lowerType(context, operand.type.type);
+ }
+ return builder->emitSPIRVAsmOperandInst(i);
+ }
+ }
+ SLANG_UNREACHABLE("Unhandled case in visitSPIRVAsmExpr");
+ };
+ IRBuilderInsertLocScope insertScope(builder);
+ builder->setInsertInto(spirvAsmInst);
+ for(const auto& inst : expr->insts)
+ {
+ const auto opcode = lowerOperand(inst.opcode);
+ List<IRInst*> operands;
+ for(const auto& operand : inst.operands)
+ operands.add(lowerOperand(operand));
+ builder->emitSPIRVAsmInst(opcode, operands);
+ }
+ return LoweredValInfo::simple(spirvAsmInst);
+ }
+
LoweredValInfo visitIndexExpr(IndexExpr* expr)
{
auto type = lowerType(context, expr->type);
diff --git a/source/slang/slang-parser.cpp b/source/slang/slang-parser.cpp
index fabf1ffb4..442ddbce6 100644
--- a/source/slang/slang-parser.cpp
+++ b/source/slang/slang-parser.cpp
@@ -2,6 +2,7 @@
#include <assert.h>
#include <float.h>
+#include <optional>
#include "slang-compiler.h"
#include "slang-lookup.h"
@@ -6147,6 +6148,120 @@ namespace Slang
}
}
+ static std::optional<SPIRVAsmOperand> parseSPIRVAsmOperand(Parser* parser)
+ {
+ const auto slangIdentOperand = [&](auto flavor){
+ const auto tok = parser->ReadToken(TokenType::Identifier);
+
+ VarExpr* varExpr = parser->astBuilder->create<VarExpr>();
+ varExpr->scope = parser->currentScope;
+ varExpr->loc = tok.getLoc();
+ varExpr->name = tok.getName();
+ return SPIRVAsmOperand{flavor, tok, varExpr};
+ };
+
+ // A regular identifier
+ if(parser->LookAheadToken(TokenType::Identifier))
+ {
+ return SPIRVAsmOperand{SPIRVAsmOperand::NamedValue, parser->ReadToken()};
+ }
+ // A literal integer or string
+ else if(parser->LookAheadToken(TokenType::IntegerLiteral)
+ || parser->LookAheadToken(TokenType::StringLiteral))
+ {
+ return SPIRVAsmOperand{SPIRVAsmOperand::Literal, parser->ReadToken()};
+ }
+ // A %foo id
+ else if(AdvanceIf(parser, TokenType::OpMod))
+ {
+ if(parser->LookAheadToken(TokenType::IntegerLiteral)
+ || parser->LookAheadToken(TokenType::Identifier))
+ {
+ return SPIRVAsmOperand{SPIRVAsmOperand::Id, parser->ReadToken()};
+ }
+ }
+ // A &foo variable reference (for the address of foo)
+ else if(AdvanceIf(parser, TokenType::OpBitAnd))
+ {
+ return slangIdentOperand(SPIRVAsmOperand::SlangValueAddr);
+ }
+ // A $foo variable
+ else if(AdvanceIf(parser, TokenType::Dollar))
+ {
+ return slangIdentOperand(SPIRVAsmOperand::SlangValue);
+ }
+ // A $$foo type
+ else if(AdvanceIf(parser, TokenType::DollarDollar))
+ {
+ return slangIdentOperand(SPIRVAsmOperand::SlangType);
+ }
+
+ Unexpected(parser);
+ return std::nullopt;
+ }
+
+ static std::optional<SPIRVAsmInst> parseSPIRVAsmInst(Parser* parser)
+ {
+ SPIRVAsmInst ret;
+
+ const auto resultOrOpcode = parseSPIRVAsmOperand(parser);
+ if(!resultOrOpcode)
+ return std::nullopt;
+
+ // We can enable this when we have a way of determining the index of
+ // the result id operand to each instruction, otherwise we don't know
+ // at which position in the operand list to insert this.
+#if 0
+ if(AdvanceIf(parser, TokenType::OpEql))
+ {
+ const auto opcode = parseSPIRVAsmOperand(parser);
+ if(!opcode)
+ return std::nullopt;
+ ret.opcode = *opcode;
+ ret.operands.insert(???, *resultOrOpcode);
+ }
+ else
+#endif
+ {
+ ret.opcode = *resultOrOpcode;
+ }
+
+ // TODO: diagnose wrong opcode flavor here
+
+ while(!(parser->LookAheadToken(TokenType::RBrace)
+ || parser->LookAheadToken(TokenType::Semicolon)))
+ {
+ if(const auto operand = parseSPIRVAsmOperand(parser))
+ ret.operands.add(*operand);
+ else
+ return std::nullopt;
+ }
+
+ return ret;
+ }
+
+ static Expr* parseSPIRVAsmExpr(Parser* parser)
+ {
+ SPIRVAsmExpr* asmExpr = parser->astBuilder->create<SPIRVAsmExpr>();
+
+ parser->ReadToken(TokenType::LBrace);
+ while(!parser->tokenReader.isAtEnd())
+ {
+ if(parser->LookAheadToken(TokenType::RBrace))
+ break;
+ if(const auto inst = parseSPIRVAsmInst(parser))
+ asmExpr->insts.add(*inst);
+ else
+ return nullptr;
+ if(parser->LookAheadToken(TokenType::RBrace))
+ break;
+ parser->ReadToken(TokenType::Semicolon);
+ }
+ parser->ReadToken(TokenType::RBrace);
+
+ return asmExpr;
+ }
+
static Expr* parsePrefixExpr(Parser* parser)
{
auto tokenType = peekTokenType(parser);
@@ -6179,7 +6294,10 @@ namespace Slang
}
return newExpr;
}
-
+ else if (AdvanceIf(parser, "spirv_asm"))
+ {
+ return parseSPIRVAsmExpr(parser);
+ }
return parsePostfixExpr(parser);
}
diff --git a/source/slang/slang-serialize-ast-type-info.h b/source/slang/slang-serialize-ast-type-info.h
index 5ccf9ea54..f5d636b01 100644
--- a/source/slang/slang-serialize-ast-type-info.h
+++ b/source/slang/slang-serialize-ast-type-info.h
@@ -237,6 +237,9 @@ struct SerialTypeInfo<RequirementWitness::Flavor> : public SerialConvertTypeInfo
// RequirementWitness
SLANG_VALUE_TYPE_INFO(RequirementWitness)
+// SPIRVAsm
+SLANG_VALUE_TYPE_INFO(SPIRVAsmOperand)
+SLANG_VALUE_TYPE_INFO(SPIRVAsmInst)
} // namespace Slang