diff options
| -rw-r--r-- | source/slang/emit.cpp | 43 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 4 | ||||
| -rw-r--r-- | source/slang/ir.cpp | 65 | ||||
| -rw-r--r-- | source/slang/ir.h | 26 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 44 | ||||
| -rw-r--r-- | tests/bugs/gh-103.slang | 2 |
6 files changed, 177 insertions, 7 deletions
diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index f6c63dff6..40366c43c 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -4036,6 +4036,10 @@ emitDeclImpl(decl, nullptr); emitIRVectorType(context, (IRVectorType*) type); break; + case kIROp_MatrixType: + emitIRMatrixType(context, (IRMatrixType*) type); + break; + case kIROp_StructType: emit(getName(type)); break; @@ -4106,6 +4110,18 @@ emitDeclImpl(decl, nullptr); emitIRSimpleValue(context, type->getElementCount()); } + void emitIRMatrixType( + EmitContext* context, + IRMatrixType* type) + { + // TODO: this is a GLSL-vs-HLSL decision point + + emitIRSimpleType(context, type->getElementType()); + emitIRSimpleValue(context, type->getRowCount()); + emit("x"); + emitIRSimpleValue(context, type->getColumnCount()); + } + void emitIRType( EmitContext* context, IRType* type, @@ -4163,6 +4179,7 @@ emitDeclImpl(decl, nullptr); case kIROp_IntLit: case kIROp_FloatLit: case kIROp_FieldAddress: + case kIROp_getElementPtr: return true; } @@ -4321,12 +4338,20 @@ emitDeclImpl(decl, nullptr); } break; - case kIROp_Add: - emitIROperand(context, inst->getArg(1)); - emit(" + "); - emitIROperand(context, inst->getArg(2)); - break; +#define CASE(OPCODE, OP) \ + case OPCODE: \ + emitIROperand(context, inst->getArg(1)); \ + emit("" #OP " "); \ + emitIROperand(context, inst->getArg(2)); \ + break + CASE(kIROp_Add, +); + CASE(kIROp_Sub, -); + CASE(kIROp_Mul, *); + CASE(kIROp_Div, /); + CASE(kIROp_Mod, %); + +#undef CASE case kIROp_Sample: emitIROperand(context, inst->getArg(1)); @@ -4383,6 +4408,14 @@ emitDeclImpl(decl, nullptr); emit("]"); break; + case kIROp_getElement: + case kIROp_getElementPtr: + emitIROperand(context, inst->getArg(1)); + emit("["); + emitIROperand(context, inst->getArg(2)); + emit("]"); + break; + default: emit("/* uhandled */"); break; diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index 0ed66064a..dab7aa3d7 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -17,6 +17,7 @@ INST(TypeType, type.type, 0, 0) INST(VoidType, type.void, 0, 0) INST(BlockType, type.block, 0, 0) INST(VectorType, type.vector, 2, 0) +INST(MatrixType, matrixType, 3, 0) INST(BoolType, type.bool, 0, 0) INST(Float32Type, type.f32, 0, 0) INST(Int32Type, type.i32, 0, 0) @@ -53,6 +54,9 @@ INST(BufferStore, bufferStore, 3, 0) INST(FieldExtract, get_field, 2, 0) INST(FieldAddress, get_field_addr, 2, 0) +INST(getElement, getElement, 2, 0) +INST(getElementPtr, getElementPtr, 2, 0) + INST(ReturnVal, return_val, 1, 0) INST(ReturnVoid, return_void, 1, 0) diff --git a/source/slang/ir.cpp b/source/slang/ir.cpp index 9e39579f6..3a6410125 100644 --- a/source/slang/ir.cpp +++ b/source/slang/ir.cpp @@ -534,6 +534,25 @@ namespace Slang &args[0]); } + template<typename T> + static T* findOrEmitInst( + IRBuilder* builder, + IROp op, + IRType* type, + IRInst* arg1, + IRInst* arg2, + IRInst* arg3) + { + IRInst* args[] = { arg1, arg2, arg3 }; + return (T*) findOrEmitInstImpl( + builder, + sizeof(T), + op, + type, + 3, + &args[0]); + } + // bool operator==(IRConstantKey const& left, IRConstantKey const& right) @@ -644,6 +663,20 @@ namespace Slang elementCount); } + IRType* IRBuilder::getMatrixType( + IRType* elementType, + IRValue* rowCount, + IRValue* columnCount) + { + return findOrEmitInst<IRMatrixType>( + this, + kIROp_MatrixType, + getTypeType(), + elementType, + rowCount, + columnCount); + } + IRType* IRBuilder::getTypeType() { return findOrEmitInst<IRType>( @@ -931,6 +964,38 @@ namespace Slang return inst; } + IRInst* IRBuilder::emitElementExtract( + IRType* type, + IRValue* base, + IRValue* index) + { + auto inst = createInst<IRFieldAddress>( + this, + kIROp_getElement, + type, + base, + index); + + addInst(inst); + return inst; + } + + IRInst* IRBuilder::emitElementAddress( + IRType* type, + IRValue* basePtr, + IRValue* index) + { + auto inst = createInst<IRFieldAddress>( + this, + kIROp_getElementPtr, + type, + basePtr, + index); + + addInst(inst); + return inst; + } + IRInst* IRBuilder::emitReturn( IRValue* val) { diff --git a/source/slang/ir.h b/source/slang/ir.h index 1561d15a7..6e4a25fe1 100644 --- a/source/slang/ir.h +++ b/source/slang/ir.h @@ -275,6 +275,17 @@ struct IRVectorType : IRType IRInst* getElementCount() { return elementCount.usedValue; } }; +struct IRMatrixType : IRType +{ + IRUse elementType; + IRUse rowCount; + IRUse columnCount; + + IRType* getElementType() { return (IRType*) elementType.usedValue; } + IRInst* getRowCount() { return rowCount.usedValue; } + IRInst* getColumnCount() { return columnCount.usedValue; } +}; + struct IRFuncType : IRType { IRUse resultType; @@ -497,6 +508,10 @@ struct IRBuilder IRType* getBaseType(BaseType flavor); IRType* getBoolType(); IRType* getVectorType(IRType* elementType, IRValue* elementCount); + IRType* getMatrixType( + IRType* elementType, + IRValue* rowCount, + IRValue* columnCount); IRType* getTypeType(); IRType* getVoidType(); IRType* getBlockType(); @@ -568,6 +583,17 @@ struct IRBuilder IRValue* basePtr, IRStructField* field); + IRInst* emitElementExtract( + IRType* type, + IRValue* base, + IRValue* index); + + IRInst* emitElementAddress( + IRType* type, + IRValue* basePtr, + IRValue* index); + + IRInst* emitReturn( IRValue* val); diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 4bdfa4438..10b4aefca 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -519,6 +519,15 @@ struct ValLoweringVisitor : ValVisitor<ValLoweringVisitor, LoweredValInfo, Lower return getBuilder()->getVectorType(irElementType, irElementCount); } + + LoweredTypeInfo visitMatrixExpressionType(MatrixExpressionType* type) + { + auto irElementType = lowerSimpleType(context, type->getElementType()); + auto irRowCount = lowerSimpleVal(context, type->getRowCount()); + auto irColumnCount = lowerSimpleVal(context, type->getColumnCount()); + + return getBuilder()->getMatrixType(irElementType, irRowCount, irColumnCount); + } }; LoweredValInfo lowerVal( @@ -738,9 +747,42 @@ struct ExprLoweringVisitor : ExprVisitor<ExprLoweringVisitor, LoweredValInfo> return emitCallToVal(context, type, funcVal, irArgs.Count(), irArgs.Buffer()); } + LoweredValInfo subscriptValue( + LoweredTypeInfo type, + LoweredValInfo baseVal, + IRValue* indexVal) + { + auto builder = getBuilder(); + switch (baseVal.flavor) + { + case LoweredValInfo::Flavor::Simple: + return LoweredValInfo::simple( + builder->emitElementExtract( + getSimpleType(type), + getSimpleVal(context, baseVal), + indexVal)); + + case LoweredValInfo::Flavor::Ptr: + return LoweredValInfo::ptr( + builder->emitElementAddress( + builder->getPtrType(getSimpleType(type)), + baseVal.val, + indexVal)); + + default: + SLANG_UNIMPLEMENTED_X("subscript expr"); + return LoweredValInfo(); + } + + } + LoweredValInfo visitIndexExpr(IndexExpr* expr) { - SLANG_UNIMPLEMENTED_X("codegen for subscript expression"); + auto type = lowerType(context, expr->type); + auto baseVal = lowerExpr(context, expr->BaseExpression); + auto indexVal = getSimpleVal(context, lowerExpr(context, expr->IndexExpression)); + + return subscriptValue(type, baseVal, indexVal); } LoweredValInfo extractField( diff --git a/tests/bugs/gh-103.slang b/tests/bugs/gh-103.slang index 2b10c2b3f..c810afc05 100644 --- a/tests/bugs/gh-103.slang +++ b/tests/bugs/gh-103.slang @@ -1,4 +1,4 @@ -//TEST:COMPARE_HLSL: -profile ps_4_0 -entry main +//TEST:COMPARE_HLSL: -use-ir -profile ps_4_0 -entry main // Ensure that matrix-times-scalar works |
