summaryrefslogtreecommitdiffstats
path: root/Whisper/ML/TensorEx.cpp
blob: 97e4e30c0b24d040971c1d9d8d2bb4476699000c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
#include "stdafx.h"
#include "TensorEx.h"
#include "../D3D/createBuffer.h"
#include "../source/ggml.h"
#include "../D3D/MappedResource.h"
using namespace DirectCompute;

// Re-creates this tensor on the GPU from a ggml_tensor description.
// ggml: source tensor; its type and byte size are read, and ggml.data is used when uploadData is set.
// usage: buffer usage passed to createBuffer. eBufferUse::ReadWriteDownload additionally requests a
//        staging buffer; ReadWrite and ReadWriteDownload set the makeUav flag for the views.
// uploadData: when true, ggml.data is handed to createBuffer as the source data; otherwise nullptr is passed.
// Returns DISP_E_OVERFLOW when the tensor is larger than INT_MAX bytes, E_NOTIMPL for element types
// other than GGML_TYPE_F16 / GGML_TYPE_F32, failures reported through CHECK,
// or the result of TensorGpuViews::create.
HRESULT TensorEx::create( const ggml_tensor& ggml, eBufferUse usage, bool uploadData )
{
	// Drop whatever this object was holding before
	TensorGpuViews::clear();
	buffer = nullptr;
	stagingBuffer = nullptr;

	CHECK( TensorShape::create( ggml ) );

	const size_t totalBytes = ggml_nbytes( &ggml );
	if( totalBytes > INT_MAX )
		return DISP_E_OVERFLOW;

	// Resolve the view format before allocating anything,
	// so unsupported element types fail without leaving an orphaned buffer in this object
	const ggml_type dataType = ggml.type;
	DXGI_FORMAT format;
	switch( dataType )
	{
	case GGML_TYPE_F16:
		format = DXGI_FORMAT_R16_FLOAT;
		break;
	case GGML_TYPE_F32:
		format = DXGI_FORMAT_R32_FLOAT;
		break;
	default:
		return E_NOTIMPL;
	}

	const uint32_t cbElement = (uint32_t)ggml_type_size( dataType );
	// Safe to narrow: totalBytes was verified to be within INT_MAX above
	const uint32_t countElements = (uint32_t)( totalBytes / cbElement );

	{
		const void* const rsi = uploadData ? ggml.data : nullptr;
		// The staging buffer is only requested for buffers which need to be downloaded
		ID3D11Buffer** ppStagingBuffer = ( usage == eBufferUse::ReadWriteDownload ) ? &stagingBuffer : nullptr;
		CHECK( createBuffer( usage, totalBytes, &buffer, rsi, ppStagingBuffer ) );
	}

	const bool makeUav = usage == eBufferUse::ReadWrite || usage == eBufferUse::ReadWriteDownload;
	return TensorGpuViews::create( buffer, format, countElements, makeUav );
}

// Re-creates this tensor with the specified element type, usage and size, without supplying source data.
// type: element type forwarded to Tensor::create.
// usage: buffer usage; eBufferUse::ReadWriteDownload additionally requests a staging buffer.
// sizeElements: the 4 size values forwarded to Tensor::create.
// Returns the result of Tensor::create.
HRESULT TensorEx::create( eDataType type, eBufferUse usage, const std::array<uint32_t, 4>& sizeElements )
{
	// Drop whatever this object was holding before
	TensorGpuViews::clear();
	buffer = nullptr;
	stagingBuffer = nullptr;

	// Standard brace initialization: std::initializer_list has no pointer-pair constructor in the C++ standard.
	// The backing array lives as long as `il`, i.e. until the end of this function.
	std::initializer_list<uint32_t> il = { sizeElements[ 0 ], sizeElements[ 1 ], sizeElements[ 2 ], sizeElements[ 3 ] };

	ID3D11Buffer** ppStaging = ( usage == eBufferUse::ReadWriteDownload ) ? &stagingBuffer : nullptr;
	return Tensor::create( type, il, usage, buffer, nullptr, ppStaging );
}

// Queries the shader resource view of this tensor for the element size and element count.
// cbElement: receives dxgiSizeof() of the view format.
// countElements: receives the NumElements field of the buffer view.
// Returns OLE_E_BLANK when this tensor has no shader resource view, otherwise S_OK.
HRESULT TensorEx::getViewSize( uint32_t& cbElement, uint32_t& countElements ) const
{
	ID3D11ShaderResourceView* const view = *this;
	if( nullptr != view )
	{
		D3D11_SHADER_RESOURCE_VIEW_DESC desc;
		view->GetDesc( &desc );

		cbElement = dxgiSizeof( desc.Format );

		// The view is expected to be a buffer view starting at the first element
		assert( desc.ViewDimension == D3D_SRV_DIMENSION_BUFFER );
		assert( desc.Buffer.FirstElement == 0 );
		countElements = desc.Buffer.NumElements;

		return S_OK;
	}

	// No view: the tensor was not created yet
	return OLE_E_BLANK;
}

// Copies the tensor's buffer into the staging buffer, maps the staging buffer,
// and copies the first cb bytes into the caller-supplied memory.
// rdi: destination pointer, must be non-null and have room for cb bytes.
// cb: count of bytes to copy; must not exceed the byte width of the staging buffer.
// Returns HRESULT_FROM_WIN32( ERROR_GPIO_OPERATION_DENIED ) when this tensor has no staging buffer,
// E_POINTER for a null destination, E_INVALIDARG when cb is larger than the staging buffer,
// a failure reported through CHECK when mapping fails, otherwise S_OK.
HRESULT TensorEx::download( void* rdi, size_t cb ) const
{
	if( nullptr == stagingBuffer )
		return HRESULT_FROM_WIN32( ERROR_GPIO_OPERATION_DENIED );	// The requested operation is not supported for the specified handle.
	if( nullptr == rdi )
		return E_POINTER;

	// Reject requests larger than the staging buffer, memcpy below would otherwise read past the mapped memory
	D3D11_BUFFER_DESC stagingDesc;
	stagingBuffer->GetDesc( &stagingDesc );
	if( cb > stagingDesc.ByteWidth )
		return E_INVALIDARG;

	ID3D11DeviceContext* const ctx = context();
	ctx->CopyResource( stagingBuffer, buffer );

	MappedResource mapped;
	CHECK( mapped.map( stagingBuffer, true ) );
	memcpy( rdi, mapped.data(), cb );

	return S_OK;
}

// Downloads the complete tensor into the caller-supplied memory.
// The byte count is the product of the element size and the element count reported by getViewSize().
// rdi: destination pointer, forwarded to the two-argument download().
// Returns a failure from getViewSize() reported through CHECK, otherwise the result of the two-argument download().
HRESULT TensorEx::download( void* rdi ) const
{
	uint32_t elementBytes, elementCount;
	CHECK( getViewSize( elementBytes, elementCount ) );

	// Widen before multiplying so the product is computed in size_t
	return download( rdi, static_cast<size_t>( elementBytes ) * elementCount );
}