summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/Reshaper.cpp
diff options
context:
space:
mode:
authorKonstantin <const@const.me>2023-01-16 14:52:43 +0100
committerKonstantin <const@const.me>2023-01-16 14:52:43 +0100
commit8c4603c73675958efc960fbd4bb599a2909d106a (patch)
tree714dc6fc9a1672d5fd7f89676b97e10959662abc /Whisper/ML/Reshaper.cpp
parent990a8d0dbaefc996244097397259e92758b15cce (diff)
Source codes
Diffstat (limited to 'Whisper/ML/Reshaper.cpp')
-rw-r--r--Whisper/ML/Reshaper.cpp80
1 files changed, 80 insertions, 0 deletions
diff --git a/Whisper/ML/Reshaper.cpp b/Whisper/ML/Reshaper.cpp
new file mode 100644
index 0000000..af66929
--- /dev/null
+++ b/Whisper/ML/Reshaper.cpp
@@ -0,0 +1,80 @@
+#include "stdafx.h"
+#include "Reshaper.h"
+#include "../D3D/MappedResource.h"
+#include "../D3D/Binder.h"
+#include "../D3D/shaders.h"
+#include "reshapedMultiply.h"
+
+namespace
+{
+ using namespace DirectCompute;
+ struct Constants
+ {
+ // Size and strides of the source tensor
+ TensorShape arg0;
+ uint32_t zzPadding;
+ // Count of elements per panel
+ uint32_t panelSize;
+ // Layer strides of the output matrix
+ std::array<uint32_t, 2> layerStrides;
+ };
+}
+
+HRESULT DirectCompute::Reshaper::createConstants()
+{
+ constexpr uint32_t cb = sizeof( Constants );
+ CD3D11_BUFFER_DESC desc{ cb, D3D11_BIND_CONSTANT_BUFFER, D3D11_USAGE_DYNAMIC, D3D11_CPU_ACCESS_WRITE };
+ return device()->CreateBuffer( &desc, nullptr, &constantBuffer );
+}
+
+HRESULT DirectCompute::Reshaper::makePanels( Tensor& tensor, eDataType dataType )
+{
+ if( !constantBuffer )
+ CHECK( createConstants() );
+
+ constexpr uint32_t TILE_SIZE = ReshapedMultiply::TILE_SIZE;
+
+ // Reshaping into column major horizontal panels, height = TILE_SIZE, width = width of the source matrix
+
+ std::array<uint32_t, 4> ne = tensor.ne;
+ const uint32_t groupsX = ( ne[ 1 ] + TILE_SIZE - 1 ) / TILE_SIZE;
+ ne[ 1 ] = groupsX * TILE_SIZE;;
+ // Each panel has [ size.x, TILE_SIZE ] elements
+ const uint32_t panelSize = ne[ 0 ] * TILE_SIZE;
+
+ Tensor result;
+ result.create( dataType, ne );
+
+ {
+ MappedResource mapped;
+ CHECK( mapped.map( constantBuffer, false ) );
+ Constants& cb = *(Constants*)mapped.data();
+
+ store( cb.arg0.ne, tensor.sizeVec() );
+ store( cb.arg0.nb, tensor.stridesVec() );
+ cb.panelSize = panelSize;
+ cb.layerStrides[ 0 ] = result.nb[ 2 ];
+ cb.layerStrides[ 1 ] = result.nb[ 3 ];
+ }
+
+ csSetCB( constantBuffer );
+ {
+ Binder bind;
+ bind.bind( tensor, result );
+ bindShader( eComputeShader::matReshapePanels );
+ context()->Dispatch( groupsX, tensor.ne[ 2 ], tensor.ne[ 3 ] );
+ }
+
+ tensor.nb[ 0 ] = 0;
+ tensor.nb[ 1 ] = panelSize;
+ tensor.nb[ 2 ] = result.nb[ 2 ];
+ tensor.nb[ 3 ] = result.nb[ 3 ];
+ tensor.setGpuViews( result );
+ return S_OK;
+}
+
+DirectCompute::Reshaper::~Reshaper()
+{
+ if( constantBuffer )
+ csSetCB( nullptr );
+} \ No newline at end of file