Skip to content

Commit cfffc68

Browse files
committed
Add C API for TF_StringInit.
1 parent c36d370 commit cfffc68

File tree

6 files changed

+125
-5
lines changed

6 files changed

+125
-5
lines changed

src/TensorFlowNET.Core/Tensorflow.Binding.csproj

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,11 @@ https://tensorflownet.readthedocs.io</Description>
2525
* Eager Mode is added finally.
2626
* tf.keras is partially working.
2727
* tf.data is added.
28-
* autograph works partially.
28+
* Autograph works partially.
29+
* Improve memory usage.
2930

30-
TensorFlow .NET v0.3x is focused on making more Keras API works</PackageReleaseNotes>
31+
TensorFlow .NET v0.3x is focused on making more Keras API works.
32+
Keras API is a separate package released as TensorFlow.Keras.</PackageReleaseNotes>
3133
<FileVersion>0.33.0.0</FileVersion>
3234
<PackageLicenseFile>LICENSE</PackageLicenseFile>
3335
<PackageRequireLicenseAcceptance>true</PackageRequireLicenseAcceptance>
@@ -83,7 +85,7 @@ TensorFlow .NET v0.3x is focused on making more Keras API works</PackageReleaseN
8385
<PackageReference Include="MethodBoundaryAspect.Fody" Version="2.0.138" />
8486
<PackageReference Include="Microsoft.Extensions.DependencyInjection" Version="5.0.1" />
8587
<PackageReference Include="NumSharp.Lite" Version="0.1.12" />
86-
<PackageReference Include="Protobuf.Text" Version="0.4.0" />
88+
<PackageReference Include="Protobuf.Text" Version="0.5.0" />
8789
<PackageReference Include="Serilog.Sinks.Console" Version="3.1.1" />
8890
</ItemGroup>
8991
</Project>
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow
6+
{
7+
public enum TF_TString_Type
8+
{
9+
TF_TSTR_SMALL = 0,
10+
TF_TSTR_LARGE = 1,
11+
TF_TSTR_OFFSET = 2,
12+
TF_TSTR_VIEW = 3
13+
}
14+
}
Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Util;
5+
6+
namespace Tensorflow
7+
{
8+
public class TStringHandle : SafeTensorflowHandle
9+
{
10+
protected override bool ReleaseHandle()
11+
{
12+
c_api.TF_StringDealloc(handle);
13+
return true;
14+
}
15+
}
16+
}

src/TensorFlowNET.Core/Tensors/Tensor.String.cs

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,29 @@ namespace Tensorflow
88
{
99
public partial class Tensor
1010
{
11-
public unsafe IntPtr StringTensor(string[] strings, TensorShape shape)
11+
const ulong TF_TSRING_SIZE = 24;
12+
13+
public IntPtr StringTensor25(string[] strings, TensorShape shape)
14+
{
15+
var handle = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
16+
shape.dims.Select(x => (long)x).ToArray(),
17+
shape.ndim,
18+
(ulong)shape.size * TF_TSRING_SIZE);
19+
20+
var data = c_api.TF_TensorData(handle);
21+
var tstr = c_api.TF_StringInit(handle);
22+
// AllocationHandle = tstr;
23+
// AllocationType = AllocationType.Tensorflow;
24+
for (int i = 0; i< strings.Length; i++)
25+
{
26+
c_api.TF_StringCopy(tstr, strings[i], strings[i].Length);
27+
tstr += (int)TF_TSRING_SIZE;
28+
}
29+
// c_api.TF_StringDealloc(tstr);
30+
return handle;
31+
}
32+
33+
public IntPtr StringTensor(string[] strings, TensorShape shape)
1234
{
1335
// convert string array to byte[][]
1436
var buffer = new byte[strings.Length][];
@@ -61,11 +83,27 @@ public unsafe IntPtr StringTensor(byte[][] buffer, TensorShape shape)
6183
return handle;
6284
}
6385

86+
public string[] StringData25()
87+
{
88+
string[] strings = new string[c_api.TF_Dim(_handle, 0)];
89+
var tstrings = TensorDataPointer;
90+
for (int i = 0; i< strings.Length; i++)
91+
{
92+
var tstringData = c_api.TF_StringGetDataPointer(tstrings);
93+
/*var size = c_api.TF_StringGetSize(tstrings);
94+
var capacity = c_api.TF_StringGetCapacity(tstrings);
95+
var type = c_api.TF_StringGetType(tstrings);*/
96+
strings[i] = c_api.StringPiece(tstringData);
97+
tstrings += (int)TF_TSRING_SIZE;
98+
}
99+
return strings;
100+
}
101+
64102
/// <summary>
65103
/// Extracts string array from current Tensor.
66104
/// </summary>
67105
/// <exception cref="InvalidOperationException">When <see cref="dtype"/> != TF_DataType.TF_STRING</exception>
68-
public unsafe string[] StringData()
106+
public string[] StringData()
69107
{
70108
var buffer = StringBytes();
71109

src/TensorFlowNET.Core/Tensors/c_api.tensor.cs

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,30 @@ public static unsafe IntPtr TF_NewTensor(TF_DataType dataType, long[] dims, int
181181
[DllImport(TensorFlowLibName)]
182182
public static extern unsafe ulong TF_StringEncode(byte* src, ulong src_len, byte* dst, ulong dst_len, SafeStatusHandle status);
183183

184+
[DllImport(TensorFlowLibName)]
185+
public static extern IntPtr TF_StringInit(IntPtr t);
186+
187+
[DllImport(TensorFlowLibName)]
188+
public static extern void TF_StringCopy(IntPtr dst, string text, long size);
189+
190+
[DllImport(TensorFlowLibName)]
191+
public static extern void TF_StringAssignView(IntPtr dst, IntPtr text, long size);
192+
193+
[DllImport(TensorFlowLibName)]
194+
public static extern IntPtr TF_StringGetDataPointer(IntPtr tst);
195+
196+
[DllImport(TensorFlowLibName)]
197+
public static extern TF_TString_Type TF_StringGetType(IntPtr tst);
198+
199+
[DllImport(TensorFlowLibName)]
200+
public static extern ulong TF_StringGetSize(IntPtr tst);
201+
202+
[DllImport(TensorFlowLibName)]
203+
public static extern ulong TF_StringGetCapacity(IntPtr tst);
204+
205+
[DllImport(TensorFlowLibName)]
206+
public static extern void TF_StringDealloc(IntPtr tst);
207+
184208
/// <summary>
185209
/// Decode a string encoded using TF_StringEncode.
186210
/// </summary>

test/TensorFlowNET.Native.UnitTest/Tensors/TensorTest.cs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,32 @@ public void Tensor()
107107
Assert.IsTrue(Enumerable.SequenceEqual(nd.Data<float>(), new float[] { 1, 2, 3, 4, 5, 6 }));
108108
}
109109

110+
/// <summary>
111+
/// Port from c_api_test.cc
112+
/// `TEST_F(CApiAttributesTest, StringTensor)`
113+
/// </summary>
114+
[TestMethod, Ignore("Waiting for PR https://github.com/tensorflow/tensorflow/pull/46804")]
115+
public void StringTensor()
116+
{
117+
string text = "Hello world!.";
118+
119+
var tensor = c_api.TF_AllocateTensor(TF_DataType.TF_STRING,
120+
null,
121+
0,
122+
1 * 24);
123+
var tstr = c_api.TF_StringInit(tensor);
124+
var data = c_api.TF_StringGetDataPointer(tstr);
125+
c_api.TF_StringCopy(tstr, text, text.Length);
126+
127+
Assert.AreEqual((ulong)text.Length, c_api.TF_StringGetSize(tstr));
128+
Assert.AreEqual(text, c_api.StringPiece(data));
129+
Assert.AreEqual((ulong)text.Length, c_api.TF_TensorByteSize(tensor));
130+
Assert.AreEqual(0, c_api.TF_NumDims(tensor));
131+
132+
TF_DeleteTensor(tensor);
133+
c_api.TF_StringDealloc(tstr);
134+
}
135+
110136
/// <summary>
111137
/// Port from tensorflow\c\c_api_test.cc
112138
/// `TEST(CAPI, SetShape)`

0 commit comments

Comments
 (0)