summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--source/slang/emit.cpp43
-rw-r--r--source/slang/ir-inst-defs.h4
-rw-r--r--source/slang/ir.cpp65
-rw-r--r--source/slang/ir.h26
-rw-r--r--source/slang/lower-to-ir.cpp44
-rw-r--r--tests/bugs/gh-103.slang2
6 files changed, 177 insertions, 7 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp
index f6c63dff6..40366c43c 100644
--- a/source/slang/emit.cpp
+++ b/source/slang/emit.cpp
@@ -4036,6 +4036,10 @@ emitDeclImpl(decl, nullptr);
emitIRVectorType(context, (IRVectorType*) type);
break;
+ case kIROp_MatrixType:
+ emitIRMatrixType(context, (IRMatrixType*) type);
+ break;
+
case kIROp_StructType:
emit(getName(type));
break;
@@ -4106,6 +4110,18 @@ emitDeclImpl(decl, nullptr);
emitIRSimpleValue(context, type->getElementCount());
}
+ void emitIRMatrixType(
+ EmitContext* context,
+ IRMatrixType* type)
+ {
+ // TODO: this is a GLSL-vs-HLSL decision point
+
+ emitIRSimpleType(context, type->getElementType());
+ emitIRSimpleValue(context, type->getRowCount());
+ emit("x");
+ emitIRSimpleValue(context, type->getColumnCount());
+ }
+
void emitIRType(
EmitContext* context,
IRType* type,
@@ -4163,6 +4179,7 @@ emitDeclImpl(decl, nullptr);
case kIROp_IntLit:
case kIROp_FloatLit:
case kIROp_FieldAddress:
+ case kIROp_getElementPtr:
return true;
}
@@ -4321,12 +4338,20 @@ emitDeclImpl(decl, nullptr);
}
break;
- case kIROp_Add:
- emitIROperand(context, inst->getArg(1));
- emit(" + ");
- emitIROperand(context, inst->getArg(2));
- break;
+#define CASE(OPCODE, OP) \
+ case OPCODE: \
+ emitIROperand(context, inst->getArg(1)); \
+ emit("" #OP " "); \
+ emitIROperand(context, inst->getArg(2)); \
+ break
+ CASE(kIROp_Add, +);
+ CASE(kIROp_Sub, -);
+ CASE(kIROp_Mul, *);
+ CASE(kIROp_Div, /);
+ CASE(kIROp_Mod, %);
+
+#undef CASE
case kIROp_Sample:
emitIROperand(context, inst->getArg(1));
@@ -4383,6 +4408,14 @@ emitDeclImpl(decl, nullptr);
emit("]");
break;
+ case kIROp_getElement:
+ case kIROp_getElementPtr:
+ emitIROperand(context, inst->getArg(1));
+ emit("[");
+ emitIROperand(context, inst->getArg(2));
+ emit("]");
+ break;
+
default:
emit("/* uhandled */");
break;
diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h
index 0ed66064a..dab7aa3d7 100644
--- a/source/slang/ir-inst-defs.h
+++ b/source/slang/ir-inst-defs.h
@@ -17,6 +17,7 @@ INST(TypeType, type.type, 0, 0)
INST(VoidType, type.void, 0, 0)
INST(BlockType, type.block, 0, 0)
INST(VectorType, type.vector, 2, 0)
+INST(MatrixType, matrixType, 3, 0)
INST(BoolType, type.bool, 0, 0)
INST(Float32Type, type.f32, 0, 0)
INST(Int32Type, type.i32, 0, 0)
@@ -53,6 +54,9 @@ INST(BufferStore, bufferStore, 3, 0)
INST(FieldExtract, get_field, 2, 0)
INST(FieldAddress, get_field_addr, 2, 0)
+INST(getElement, getElement, 2, 0)
+INST(getElementPtr, getElementPtr, 2, 0)
+
INST(ReturnVal, return_val, 1, 0)
INST(ReturnVoid, return_void, 1, 0)
diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp
index 9e39579f6..3a6410125 100644
--- a/source/slang/ir.cpp
+++ b/source/slang/ir.cpp
@@ -534,6 +534,25 @@ namespace Slang
&args[0]);
}
+ template<typename T>
+ static T* findOrEmitInst(
+ IRBuilder* builder,
+ IROp op,
+ IRType* type,
+ IRInst* arg1,
+ IRInst* arg2,
+ IRInst* arg3)
+ {
+ IRInst* args[] = { arg1, arg2, arg3 };
+ return (T*) findOrEmitInstImpl(
+ builder,
+ sizeof(T),
+ op,
+ type,
+ 3,
+ &args[0]);
+ }
+
//
bool operator==(IRConstantKey const& left, IRConstantKey const& right)
@@ -644,6 +663,20 @@ namespace Slang
elementCount);
}
+ IRType* IRBuilder::getMatrixType(
+ IRType* elementType,
+ IRValue* rowCount,
+ IRValue* columnCount)
+ {
+ return findOrEmitInst<IRMatrixType>(
+ this,
+ kIROp_MatrixType,
+ getTypeType(),
+ elementType,
+ rowCount,
+ columnCount);
+ }
+
IRType* IRBuilder::getTypeType()
{
return findOrEmitInst<IRType>(
@@ -931,6 +964,38 @@ namespace Slang
return inst;
}
+ IRInst* IRBuilder::emitElementExtract(
+ IRType* type,
+ IRValue* base,
+ IRValue* index)
+ {
+ auto inst = createInst<IRFieldAddress>(
+ this,
+ kIROp_getElement,
+ type,
+ base,
+ index);
+
+ addInst(inst);
+ return inst;
+ }
+
+ IRInst* IRBuilder::emitElementAddress(
+ IRType* type,
+ IRValue* basePtr,
+ IRValue* index)
+ {
+ auto inst = createInst<IRFieldAddress>(
+ this,
+ kIROp_getElementPtr,
+ type,
+ basePtr,
+ index);
+
+ addInst(inst);
+ return inst;
+ }
+
IRInst* IRBuilder::emitReturn(
IRValue* val)
{
diff --git a/source/slang/ir.h b/source/slang/ir.h
index 1561d15a7..6e4a25fe1 100644
--- a/source/slang/ir.h
+++ b/source/slang/ir.h
@@ -275,6 +275,17 @@ struct IRVectorType : IRType
IRInst* getElementCount() { return elementCount.usedValue; }
};
+struct IRMatrixType : IRType
+{
+ IRUse elementType;
+ IRUse rowCount;
+ IRUse columnCount;
+
+ IRType* getElementType() { return (IRType*) elementType.usedValue; }
+ IRInst* getRowCount() { return rowCount.usedValue; }
+ IRInst* getColumnCount() { return columnCount.usedValue; }
+};
+
struct IRFuncType : IRType
{
IRUse resultType;
@@ -497,6 +508,10 @@ struct IRBuilder
IRType* getBaseType(BaseType flavor);
IRType* getBoolType();
IRType* getVectorType(IRType* elementType, IRValue* elementCount);
+ IRType* getMatrixType(
+ IRType* elementType,
+ IRValue* rowCount,
+ IRValue* columnCount);
IRType* getTypeType();
IRType* getVoidType();
IRType* getBlockType();
@@ -568,6 +583,17 @@ struct IRBuilder
IRValue* basePtr,
IRStructField* field);
+ IRInst* emitElementExtract(
+ IRType* type,
+ IRValue* base,
+ IRValue* index);
+
+ IRInst* emitElementAddress(
+ IRType* type,
+ IRValue* basePtr,
+ IRValue* index);
+
+
IRInst* emitReturn(
IRValue* val);
diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp
index 4bdfa4438..10b4aefca 100644
--- a/source/slang/lower-to-ir.cpp
+++ b/source/slang/lower-to-ir.cpp
@@ -519,6 +519,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower
return getBuilder()->getVectorType(irElementType, irElementCount);
}
+
+ LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type)
+ {
+ auto irElementType = lowerSimpleType(context, type->getElementType());
+ auto irRowCount = lowerSimpleVal(context, type->getRowCount());
+ auto irColumnCount = lowerSimpleVal(context, type->getColumnCount());
+
+ return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount);
+ }
};
LoweredValInfo lowerVal(
@@ -738,9 +747,42 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo>
return emitCallToVal(context, type, funcVal, irArgs.Count(), irArgs.Buffer());
}
+ LoweredValInfo subscriptValue(
+ LoweredTypeInfo type,
+ LoweredValInfo baseVal,
+ IRValue* indexVal)
+ {
+ auto builder = getBuilder();
+ switch (baseVal.flavor)
+ {
+ case LoweredValInfo::Flavor::Simple:
+ return LoweredValInfo::simple(
+ builder->emitElementExtract(
+ getSimpleType(type),
+ getSimpleVal(context, baseVal),
+ indexVal));
+
+ case LoweredValInfo::Flavor::Ptr:
+ return LoweredValInfo::ptr(
+ builder->emitElementAddress(
+ builder->getPtrType(getSimpleType(type)),
+ baseVal.val,
+ indexVal));
+
+ default:
+ SLANG_UNIMPLEMENTED_X("subscript expr");
+ return LoweredValInfo();
+ }
+
+ }
+
LoweredValInfo visitIndexExpr(IndexExpr* expr)
{
- SLANG_UNIMPLEMENTED_X("codegen for subscript expression");
+ auto type = lowerType(context, expr->type);
+ auto baseVal = lowerExpr(context, expr->BaseExpression);
+ auto indexVal = getSimpleVal(context, lowerExpr(context, expr->IndexExpression));
+
+ return subscriptValue(type, baseVal, indexVal);
}
LoweredValInfo extractField(
diff --git a/tests/bugs/gh-103.slang b/tests/bugs/gh-103.slang
index 2b10c2b3f..c810afc05 100644
--- a/tests/bugs/gh-103.slang
+++ b/tests/bugs/gh-103.slang
@@ -1,4 +1,4 @@
-//TEST:COMPARE_HLSL: -profile ps_4_0 -entry main
+//TEST:COMPARE_HLSL: -use-ir -profile ps_4_0 -entry main
// Ensure that matrix-times-scalar works