diff options
| author | Tim Foley <tfoleyNV@users.noreply.github.com> | 2017-12-20 17:35:10 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2017-12-20 17:35:10 -0800 |
| commit | 6f681279d99e72e717bb2b91763b80e570ae725b (patch) | |
| tree | 501c547ff405aa5227a0ad165b9ec371fcd94ef8 /source | |
| parent | 35318fb2b08c82f80cbd464e93d81ebe719c40be (diff) | |
IR: fixes for subscript accessors (#322)
* IR: fixes for subscript accessors
Fixes #320
This is a bunch of fixes for handling of `__subscript` operations on builtin types (notably `RWStructuredBuffer` and `StructuredBuffer` at this point).
- Automatically add a `GetterDecl` to any subscript decalratio was declithout any accessors. This avoids hitting a null- dereference in the emit logic.
- Add a notion of a `RefAccessor` (declared with `ref`) as a peer to getters and setters. The idea is that a `ref` accessor returns a pointer to the element data, so that it can be used for both getting and setting values. This is closer to the behavior of `RWStructuredBuffer` element access in HLSL.
- Fixes for dealing with "access chains" where there might be a combination of a subscript (where the is a `get` and `set` but no `ref`) and member access, so that we have to read the base value into a temp, modify it, and then write it back.
- This logic is still a bit of a mess, so we will eventually want to take a more consistent pass over this to deal with how we "materialize" values for setters.
- Update `RWStructuredBuffer` to have a `ref` accessor, and then fix up the IR tests to handle the new opcode that I added for it.
- Note: I didn't handle this as an intrinsic simply because the `tests/ir/*` tests aren't really set up to handle builtins with ugly mangled names.
* Fixup: type error in VM for buffer element ref
I was using the result type of the op as the element type for computing the element address, but the result type is a pointer to the real element type.
This caused test failures on 64-bit platforms, where the stride of the buffer in the `ir/factorial` test needs to be 4.
The fix is to assume the result type is a pointer, and extract the pointed-to type out of that.
Diffstat (limited to 'source')
| -rw-r--r-- | source/slang/check.cpp | 30 | ||||
| -rw-r--r-- | source/slang/decl-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/emit.cpp | 2 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang | 10 | ||||
| -rw-r--r-- | source/slang/hlsl.meta.slang.h | 10 | ||||
| -rw-r--r-- | source/slang/ir-inst-defs.h | 1 | ||||
| -rw-r--r-- | source/slang/lower-to-ir.cpp | 263 | ||||
| -rw-r--r-- | source/slang/parser.cpp | 4 | ||||
| -rw-r--r-- | source/slang/vm.cpp | 27 |
9 files changed, 292 insertions, 56 deletions
diff --git a/source/slang/check.cpp b/source/slang/check.cpp index 4840ae30d..b981fb778 100644 --- a/source/slang/check.cpp +++ b/source/slang/check.cpp @@ -3575,6 +3575,32 @@ namespace Slang decl->SetCheckState(DeclCheckState::CheckedHeader); + // If we have a subscript declaration with no accessor declarations, + // then we should create a single `GetterDecl` to represent + // the implicit meaning of their declaration, so: + // + // subscript(uint index) -> T; + // + // becomes: + // + // subscript(uint index) -> T { get; } + // + + bool anyAccessors = false; + for(auto accessorDecl : decl->getMembersOfType<AccessorDecl>()) + { + anyAccessors = true; + } + + if(!anyAccessors) + { + RefPtr<GetterDecl> getterDecl = new GetterDecl(); + getterDecl->loc = decl->loc; + + getterDecl->ParentDecl = decl; + decl->Members.Add(getterDecl); + } + for(auto mm : decl->Members) { checkDecl(mm); @@ -4662,6 +4688,10 @@ namespace Slang { callExpr->type.IsLeftValue = true; } + for(auto refAccessor : subscriptDeclRef.getDecl()->getMembersOfType<RefAccessorDecl>()) + { + callExpr->type.IsLeftValue = true; + } } // TODO: there may be other cases that confer l-value-ness diff --git a/source/slang/decl-defs.h b/source/slang/decl-defs.h index e24a535c5..7cb9ffc0f 100644 --- a/source/slang/decl-defs.h +++ b/source/slang/decl-defs.h @@ -179,6 +179,7 @@ SIMPLE_SYNTAX_CLASS(AccessorDecl, FunctionDeclBase) SIMPLE_SYNTAX_CLASS(GetterDecl, AccessorDecl) SIMPLE_SYNTAX_CLASS(SetterDecl, AccessorDecl) +SIMPLE_SYNTAX_CLASS(RefAccessorDecl, AccessorDecl) SIMPLE_SYNTAX_CLASS(FuncDecl, FunctionDeclBase) diff --git a/source/slang/emit.cpp b/source/slang/emit.cpp index 311f903cc..5cd0ea832 100644 --- a/source/slang/emit.cpp +++ b/source/slang/emit.cpp @@ -4875,6 +4875,7 @@ emitDeclImpl(decl, nullptr); case kIROp_FieldAddress: case kIROp_getElementPtr: case kIROp_specialize: + case kIROp_BufferElementRef: return true; } @@ -5516,6 +5517,7 @@ emitDeclImpl(decl, nullptr); break; case kIROp_BufferLoad: + case kIROp_BufferElementRef: emitIROperand(ctx, inst->getArg(0)); emit("["); emitIROperand(ctx, inst->getArg(1)); diff --git a/source/slang/hlsl.meta.slang b/source/slang/hlsl.meta.slang index cdf720006..09cf731b5 100644 --- a/source/slang/hlsl.meta.slang +++ b/source/slang/hlsl.meta.slang @@ -44,8 +44,7 @@ struct StructuredBuffer T Load(int location); T Load(int location, out uint status); - __intrinsic_op(bufferLoad) - __subscript(uint index) -> T; + __subscript(uint index) -> T { __intrinsic_op(bufferLoad) get; }; }; __generic<T> __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer @@ -201,11 +200,8 @@ struct RWStructuredBuffer __subscript(uint index) -> T { - __intrinsic_op(bufferLoad) - get; - - __intrinsic_op(bufferStore) - set; + __intrinsic_op(bufferElementRef) + ref; } }; diff --git a/source/slang/hlsl.meta.slang.h b/source/slang/hlsl.meta.slang.h index c9ccfcc81..acaef8fbd 100644 --- a/source/slang/hlsl.meta.slang.h +++ b/source/slang/hlsl.meta.slang.h @@ -45,8 +45,7 @@ sb << "\n"; sb << " T Load(int location);\n"; sb << " T Load(int location, out uint status);\n"; sb << "\n"; -sb << " __intrinsic_op(bufferLoad)\n"; -sb << " __subscript(uint index) -> T;\n"; +sb << " __subscript(uint index) -> T { __intrinsic_op(bufferLoad) get; };\n"; sb << "};\n"; sb << "\n"; sb << "__generic<T> __magic_type(HLSLConsumeStructuredBufferType) struct ConsumeStructuredBuffer\n"; @@ -203,11 +202,8 @@ sb << " T Load(int location, out uint status);\n"; sb << "\n"; sb << "\t__subscript(uint index) -> T\n"; sb << "\t{\n"; -sb << "\t\t__intrinsic_op(bufferLoad)\n"; -sb << "\t\tget;\n"; -sb << "\n"; -sb << "\t\t__intrinsic_op(bufferStore)\n"; -sb << "\t\tset;\n"; +sb << " __intrinsic_op(bufferElementRef)\n"; +sb << " ref;\n"; sb << "\t}\n"; sb << "};\n"; sb << "\n"; diff --git a/source/slang/ir-inst-defs.h b/source/slang/ir-inst-defs.h index ffad04467..7eafe89f7 100644 --- a/source/slang/ir-inst-defs.h +++ b/source/slang/ir-inst-defs.h @@ -119,6 +119,7 @@ INST(Store, store, 2, 0) INST(BufferLoad, bufferLoad, 2, 0) INST(BufferStore, bufferStore, 3, 0) +INST(BufferElementRef, bufferElementRef, 2, 0) INST(FieldExtract, get_field, 2, 0) INST(FieldAddress, get_field_addr, 2, 0) diff --git a/source/slang/lower-to-ir.cpp b/source/slang/lower-to-ir.cpp index 8c8b34908..8fc1239cd 100644 --- a/source/slang/lower-to-ir.cpp +++ b/source/slang/lower-to-ir.cpp @@ -207,6 +207,9 @@ struct BoundMemberInfo : ExtendedValueInfo // The (AST-level) declaration reference. DeclRef<Decl> declRef; + + // The type of this value + RefPtr<Type> type; }; // Represents the result of a swizzle operation in @@ -362,6 +365,7 @@ LoweredValInfo emitCallToVal( switch (funcVal.flavor) { case LoweredValInfo::Flavor::None: + SLANG_UNEXPECTED("null function"); default: return LoweredValInfo::simple( builder->emitCallInst(type, getSimpleVal(context, funcVal), argCount, args)); @@ -531,25 +535,60 @@ LoweredValInfo emitCallToDeclRef( if (auto subscriptDeclRef = funcDeclRef.As<SubscriptDecl>()) { - // A reference to a subscript declaration is potentially a - // special case, if we have more than just a getter. + // A reference to a subscript declaration is a special case, + // because it is not possible to call a subscript directly; + // we must call one of its accessors. + // + // TODO: everything here will also apply to propery declarations + // once we have them, so some of this code might be shared + // some day. DeclRef<GetterDecl> getterDeclRef; bool justAGetter = true; for (auto accessorDeclRef : getMembersOfType<AccessorDecl>(subscriptDeclRef)) { + // If the subscript declares a `ref` accessor, then we can just + // invoke that directly to get an l-value we can use. + if(auto refAccessorDeclRef = accessorDeclRef.As<RefAccessorDecl>()) + { + // The `ref` accessor will return a pointer to the value, so + // we need to reflect that in the type of our `call` instruction. + RefPtr<Type> ptrType = context->getSession()->getPtrType(type); + + // Rather than call `emitCallToVal` here, we make a recursive call + // to `emitCallToDeclRef` so that it can handle things like intrinsic-op + // modifiers attached to the acecssor. + LoweredValInfo callVal = emitCallToDeclRef( + context, + ptrType, + refAccessorDeclRef, + funcExpr, + argCount, + args); + + // The result from the call needs to be implicitly dereferenced, + // so that it can work as an l-value of the desired result type. + return LoweredValInfo::ptr(getSimpleVal(context, callVal)); + } + + // If we don't find a `ref` accessor, then we want to track whether + // this subscript has any accessors other than `get` (assuming + // that everything except `get` can be used for setting...). + if (auto foundGetterDeclRef = accessorDeclRef.As<GetterDecl>()) { + // We found a getter. getterDeclRef = foundGetterDeclRef; } else { + // There was something other than a getter, so we can't + // invoke an accessor just now. justAGetter = false; - break; } } - if (!justAGetter) + if (!justAGetter || !getterDeclRef) { // We can't perform an actual call right now, because // this expression might appear in an r-value or l-value @@ -575,8 +614,7 @@ LoweredValInfo emitCallToDeclRef( // Otherwise we are just call the getter, and so that // is what we need to be emitting a call to... - if (getterDeclRef) - funcDeclRef = getterDeclRef; + funcDeclRef = getterDeclRef; } auto funcDecl = funcDeclRef.getDecl(); @@ -660,7 +698,64 @@ LoweredValInfo emitCallToDeclRef( return emitCallToDeclRef(context, type, funcDeclRef, funcExpr, args.Count(), args.Buffer()); } -IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) +LoweredValInfo extractField( + IRGenContext* context, + Type* fieldType, + LoweredValInfo base, + DeclRef<StructField> field) +{ + IRBuilder* builder = context->irBuilder; + + switch (base.flavor) + { + default: + { + IRValue* irBase = getSimpleVal(context, base); + return LoweredValInfo::simple( + builder->emitFieldExtract( + fieldType, + irBase, + builder->getDeclRefVal(field))); + } + break; + + case LoweredValInfo::Flavor::BoundMember: + case LoweredValInfo::Flavor::BoundSubscript: + { + // The base value is one that is trying to defer a get-vs-set + // decision, so we will need to do the same. + + RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo(); + boundMemberInfo->type = fieldType; + boundMemberInfo->base = base; + boundMemberInfo->declRef = field; + + context->shared->extValues.Add(boundMemberInfo); + return LoweredValInfo::boundMember(boundMemberInfo); + } + break; + + case LoweredValInfo::Flavor::Ptr: + { + // We are "extracting" a field from an lvalue address, + // which means we should just compute an lvalue + // representing the field address. + IRValue* irBasePtr = base.val; + return LoweredValInfo::ptr( + builder->emitFieldAddress( + context->getSession()->getPtrType(fieldType), + irBasePtr, + builder->getDeclRefVal(field))); + } + break; + } +} + + + +LoweredValInfo materialize( + IRGenContext* context, + LoweredValInfo lowered) { auto builder = context->irBuilder; @@ -668,13 +763,9 @@ top: switch(lowered.flavor) { case LoweredValInfo::Flavor::None: - return nullptr; - case LoweredValInfo::Flavor::Simple: - return lowered.val; - case LoweredValInfo::Flavor::Ptr: - return builder->emitLoad(lowered.val); + return lowered; case LoweredValInfo::Flavor::BoundSubscript: { @@ -693,7 +784,27 @@ top: } SLANG_UNEXPECTED("subscript had no getter"); - UNREACHABLE_RETURN(nullptr); + UNREACHABLE_RETURN(LoweredValInfo()); + } + break; + + case LoweredValInfo::Flavor::BoundMember: + { + auto boundMemberInfo = lowered.getBoundMemberInfo(); + auto base = materialize(context, boundMemberInfo->base); + + auto declRef = boundMemberInfo->declRef; + if( auto fieldDeclRef = declRef.As<StructField>() ) + { + lowered = extractField(context, boundMemberInfo->type, base, fieldDeclRef); + goto top; + } + else + { + + SLANG_UNEXPECTED("unexpected member flavor"); + UNREACHABLE_RETURN(LoweredValInfo()); + } } break; @@ -701,15 +812,41 @@ top: { auto swizzleInfo = lowered.getSwizzledLValueInfo(); - return builder->emitSwizzle( + return LoweredValInfo::simple(builder->emitSwizzle( swizzleInfo->type, getSimpleVal(context, swizzleInfo->base), swizzleInfo->elementCount, - swizzleInfo->elementIndices); + swizzleInfo->elementIndices)); } default: SLANG_UNEXPECTED("unhandled value flavor"); + UNREACHABLE_RETURN(LoweredValInfo()); + } + +} + +IRValue* getSimpleVal(IRGenContext* context, LoweredValInfo lowered) +{ + auto builder = context->irBuilder; + + // First, try to eliminate any "bound" operations along the chain, + // so that we are dealing with an ordinary value, or an l-value pointer. + lowered = materialize(context, lowered); + + switch(lowered.flavor) + { + case LoweredValInfo::Flavor::None: + return nullptr; + + case LoweredValInfo::Flavor::Simple: + return lowered.val; + + case LoweredValInfo::Flavor::Ptr: + return builder->emitLoad(lowered.val); + + default: + SLANG_UNEXPECTED("unhandled value flavor"); UNREACHABLE_RETURN(nullptr); } } @@ -1042,6 +1179,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> else if (auto callableDeclRef = declRef.As<CallableDecl>()) { RefPtr<BoundMemberInfo> boundMemberInfo = new BoundMemberInfo(); + boundMemberInfo->type = nullptr; boundMemberInfo->base = loweredBase; boundMemberInfo->declRef = callableDeclRef; return LoweredValInfo::boundMember(boundMemberInfo); @@ -1543,33 +1681,7 @@ struct ExprLoweringVisitorBase : ExprVisitor<Derived, LoweredValInfo> LoweredValInfo base, DeclRef<StructField> field) { - switch (base.flavor) - { - default: - { - IRValue* irBase = getSimpleVal(context, base); - return LoweredValInfo::simple( - getBuilder()->emitFieldExtract( - getSimpleType(fieldType), - irBase, - getBuilder()->getDeclRefVal(field))); - } - break; - - case LoweredValInfo::Flavor::Ptr: - { - // We are "extracting" a field from an lvalue address, - // which means we should just compute an lvalue - // representing the field address. - IRValue* irBasePtr = base.val; - return LoweredValInfo::ptr( - getBuilder()->emitFieldAddress( - context->getSession()->getPtrType(getSimpleType(fieldType)), - irBasePtr, - getBuilder()->getDeclRefVal(field))); - } - break; - } + return Slang::extractField(context, getSimpleType(fieldType), base, field); } LoweredValInfo visitStaticMemberExpr(StaticMemberExpr* expr) @@ -2392,6 +2504,28 @@ void lowerStmt( return visitor.dispatch(stmt); } +static LoweredValInfo maybeMoveMutableTemp( + IRGenContext* context, + LoweredValInfo const& val) +{ + switch(val.flavor) + { + case LoweredValInfo::Flavor::Ptr: + return val; + + default: + { + IRValue* irVal = getSimpleVal(context, val); + auto type = irVal->getType(); + auto var = createVar(context, type); + + assign(context, var, LoweredValInfo::simple(irVal)); + return var; + } + break; + } +} + void assign( IRGenContext* context, LoweredValInfo const& inLeft, @@ -2412,6 +2546,7 @@ top: case LoweredValInfo::Flavor::Ptr: case LoweredValInfo::Flavor::SwizzledLValue: case LoweredValInfo::Flavor::BoundSubscript: + case LoweredValInfo::Flavor::BoundMember: { builder->emitStore( left.val, @@ -2511,6 +2646,43 @@ top: } break; + case LoweredValInfo::Flavor::BoundMember: + { + auto boundMemberInfo = left.getBoundMemberInfo(); + + // If we hit this case, then it means that we are trying to set + // a single field in someting that is not atomically set-able. + // (e.g., an element of a value where the `subscript` operation + // has `get` and `set` but not a `ref` accessor). + // + // We need to read the entire base value out, modify the field + // we care about, and then write it back. + + auto declRef = boundMemberInfo->declRef; + if( auto fieldDeclRef = declRef.As<StructField>() ) + { + // materialize the base value and move it into + // a mutable temporary if needed + auto baseVal = boundMemberInfo->base; + auto tempVal = maybeMoveMutableTemp(context, materialize(context, baseVal)); + + // extract the field l-value out of the temporary + auto tempFieldVal = extractField(context, boundMemberInfo->type, tempVal, fieldDeclRef); + + // assign to the field of the temporary l-value + assign(context, tempFieldVal, right); + + // write back the modified temporary to the base l-value + assign(context, baseVal, tempVal); + } + else + { + SLANG_UNEXPECTED("handled member flavor"); + } + + } + break; + default: SLANG_UNIMPLEMENTED_X("assignment"); break; @@ -3303,6 +3475,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo> irResultType = context->getSession()->getVoidType(); } + if( auto refAccessorDecl = dynamic_cast<RefAccessorDecl*>(decl) ) + { + // A `ref` accessor needs to return a *pointer* to the value + // being accessed, rather than a simple value. + irResultType = context->getSession()->getPtrType(irResultType); + } + auto irFuncType = getFuncType( context, paramTypes.Count(), diff --git a/source/slang/parser.cpp b/source/slang/parser.cpp index 3c63c9b56..35cc96b5c 100644 --- a/source/slang/parser.cpp +++ b/source/slang/parser.cpp @@ -2313,6 +2313,10 @@ namespace Slang { decl = new SetterDecl(); } + else if( AdvanceIf(parser, "ref") ) + { + decl = new RefAccessorDecl(); + } else { Unexpected(parser); diff --git a/source/slang/vm.cpp b/source/slang/vm.cpp index d795a841b..ffc455232 100644 --- a/source/slang/vm.cpp +++ b/source/slang/vm.cpp @@ -877,6 +877,33 @@ void resumeThread( memcpy(elementData, srcPtr, size); } break; + + case kIROp_BufferElementRef: + { + VMType ptrType = decodeType(frame, &ip); + VMType type = ((VMPtrTypeImpl*)ptrType.getImpl())->base; + + UInt argCount = decodeUInt(&ip); + void* argPtrs[16] = { 0 }; + for( UInt aa = 0; aa < argCount; ++aa ) + { + void* argPtr = decodeOperandPtr<void>(frame, &ip); + argPtrs[aa] = argPtr; + } + + void* dest = decodeOperandPtr<void>(frame, &ip); + + char* bufferData = *(char**)argPtrs[0]; + uint32_t index = *(uint32_t*)argPtrs[1]; + + auto size = type.getSize(); + char* elementData = bufferData + index*size; + + *(void**)dest = elementData; + } + break; + + case kIROp_Call: { VMType type = decodeType(frame, &ip); |
