Skip to content

Commit 9da157f

Browse files
authored
Merge pull request #1097 from AsakusaRinne/rnn-dev
feat: add rnn basic modules
2 parents f45b35b + 537b3e1 commit 9da157f

File tree

98 files changed

+2809
-293
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

98 files changed

+2809
-293
lines changed

src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs renamed to src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
using System.Collections.Generic;
44
using System.Text;
55

6-
namespace Tensorflow.Extensions
6+
namespace Tensorflow.Common.Extensions
77
{
88
public static class JObjectExtensions
99
{
1010
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
1111
{
1212
var res = obj[key];
13-
if(res is null)
13+
if (res is null)
1414
{
15-
return default(T);
15+
return default;
1616
}
1717
else
1818
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class LinqExtensions
9+
{
10+
#if NETSTANDARD2_0
11+
public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
12+
{
13+
return sequence.Skip(sequence.Count() - count);
14+
}
15+
16+
public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
17+
{
18+
return sequence.Take(sequence.Count() - count);
19+
}
20+
#endif
21+
public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
22+
{
23+
return new Tensors(tensors);
24+
}
25+
26+
public static void Deconstruct<T1, T2, T3>(this (T1, T2, T3) values, out T1 first, out T2 second, out T3 third)
27+
{
28+
first = values.Item1;
29+
second = values.Item2;
30+
third = values.Item3;
31+
}
32+
}
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Common.Types;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
public static class NestExtensions
9+
{
10+
public static Tensors ToTensors(this INestable<Tensor> tensors)
11+
{
12+
return new Tensors(tensors.AsNest());
13+
}
14+
15+
public static Tensors? ToTensors(this Nest<Tensor> tensors)
16+
{
17+
return Tensors.FromNest(tensors);
18+
}
19+
20+
/// <summary>
21+
/// If the nested object is already a nested type, this function could reduce it.
22+
/// For example, `Nest[Nest[T]]` can be reduced to `Nest[T]`.
23+
/// </summary>
24+
/// <typeparam name="TIn"></typeparam>
25+
/// <typeparam name="TOut"></typeparam>
26+
/// <param name="input"></param>
27+
/// <returns></returns>
28+
public static Nest<TOut> ReduceTo<TIn, TOut>(this INestStructure<TIn> input) where TIn: INestStructure<TOut>
29+
{
30+
return Nest<TOut>.ReduceFrom(input);
31+
}
32+
}
33+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Types
7+
{
8+
public class GeneralizedTensorShape: IEnumerable<long?[]>, INestStructure<long?>, INestable<long?>
9+
{
10+
public TensorShapeConfig[] Shapes { get; set; }
11+
/// <summary>
12+
/// create a single-dim generalized Tensor shape.
13+
/// </summary>
14+
/// <param name="dim"></param>
15+
public GeneralizedTensorShape(int dim)
16+
{
17+
Shapes = new TensorShapeConfig[] { new TensorShapeConfig() { Items = new long?[] { dim } } };
18+
}
19+
20+
public GeneralizedTensorShape(Shape shape)
21+
{
22+
Shapes = new TensorShapeConfig[] { shape };
23+
}
24+
25+
public GeneralizedTensorShape(TensorShapeConfig shape)
26+
{
27+
Shapes = new TensorShapeConfig[] { shape };
28+
}
29+
30+
public GeneralizedTensorShape(TensorShapeConfig[] shapes)
31+
{
32+
Shapes = shapes;
33+
}
34+
35+
public GeneralizedTensorShape(IEnumerable<Shape> shape)
36+
{
37+
Shapes = shape.Select(x => (TensorShapeConfig)x).ToArray();
38+
}
39+
40+
public Shape ToSingleShape()
41+
{
42+
if (Shapes.Length != 1)
43+
{
44+
throw new ValueError("The generalized shape contains more than 1 dim.");
45+
}
46+
var shape_config = Shapes[0];
47+
Debug.Assert(shape_config is not null);
48+
return new Shape(shape_config.Items.Select(x => x is null ? -1 : x.Value).ToArray());
49+
}
50+
51+
public long ToNumber()
52+
{
53+
if(Shapes.Length != 1 || Shapes[0].Items.Length != 1)
54+
{
55+
throw new ValueError("The generalized shape contains more than 1 dim.");
56+
}
57+
var res = Shapes[0].Items[0];
58+
return res is null ? -1 : res.Value;
59+
}
60+
61+
public Shape[] ToShapeArray()
62+
{
63+
return Shapes.Select(x => new Shape(x.Items.Select(y => y is null ? -1 : y.Value).ToArray())).ToArray();
64+
}
65+
66+
public IEnumerable<long?> Flatten()
67+
{
68+
List<long?> result = new List<long?>();
69+
foreach(var shapeConfig in Shapes)
70+
{
71+
result.AddRange(shapeConfig.Items);
72+
}
73+
return result;
74+
}
75+
public INestStructure<TOut> MapStructure<TOut>(Func<long?, TOut> func)
76+
{
77+
List<Nest<TOut>> lists = new();
78+
foreach(var shapeConfig in Shapes)
79+
{
80+
lists.Add(new Nest<TOut>(shapeConfig.Items.Select(x => new Nest<TOut>(func(x)))));
81+
}
82+
return new Nest<TOut>(lists);
83+
}
84+
85+
public Nest<long?> AsNest()
86+
{
87+
Nest<long?> DealWithSingleShape(TensorShapeConfig config)
88+
{
89+
if (config.Items.Length == 0)
90+
{
91+
return Nest<long?>.Empty;
92+
}
93+
else if (config.Items.Length == 1)
94+
{
95+
return new Nest<long?>(config.Items[0]);
96+
}
97+
else
98+
{
99+
return new Nest<long?>(config.Items.Select(x => new Nest<long?>(x)));
100+
}
101+
}
102+
103+
if(Shapes.Length == 0)
104+
{
105+
return Nest<long?>.Empty;
106+
}
107+
else if(Shapes.Length == 1)
108+
{
109+
return DealWithSingleShape(Shapes[0]);
110+
}
111+
else
112+
{
113+
return new Nest<long?>(Shapes.Select(s => DealWithSingleShape(s)));
114+
}
115+
}
116+
117+
public IEnumerator<long?[]> GetEnumerator()
118+
{
119+
foreach (var shape in Shapes)
120+
{
121+
yield return shape.Items;
122+
}
123+
}
124+
125+
IEnumerator IEnumerable.GetEnumerator()
126+
{
127+
return GetEnumerator();
128+
}
129+
}
130+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This interface indicates that a class may have a nested structure and provide
9+
/// methods to manipulate with the structure.
10+
/// </summary>
11+
public interface INestStructure<T>: INestable<T>
12+
{
13+
/// <summary>
14+
/// Flatten the Nestable object. Node that if the object contains only one value,
15+
/// it will be flattened to an enumerable with one element.
16+
/// </summary>
17+
/// <returns></returns>
18+
IEnumerable<T> Flatten();
19+
/// <summary>
20+
/// Construct a new object with the same nested structure.
21+
/// </summary>
22+
/// <typeparam name="TOut"></typeparam>
23+
/// <param name="func"></param>
24+
/// <returns></returns>
25+
INestStructure<TOut> MapStructure<TOut>(Func<T, TOut> func);
26+
}
27+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
public interface INestable<T>
8+
{
9+
Nest<T> AsNest();
10+
}
11+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// This interface is used when some corresponding python methods have optional args.
9+
/// For example, `Keras.Layer.Apply` generally takes three args as the inputs, while
10+
/// `Keras.Layer.RNN` takes more. Then when calling RNN, you should add `RnnOptionalArgs`
11+
/// as the parameter of the method.
12+
/// </summary>
13+
public interface IOptionalArgs
14+
{
15+
/// <summary>
16+
/// The identifier of the class. It is not an argument but only something to
17+
/// separate different OptionalArgs.
18+
/// </summary>
19+
string Identifier { get; }
20+
}
21+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
public static class Nest
8+
{
9+
/// <summary>
10+
/// Pack the flat items to a nested sequence by the template.
11+
/// </summary>
12+
/// <typeparam name="T"></typeparam>
13+
/// <param name="template"></param>
14+
/// <param name="flatItems"></param>
15+
/// <returns></returns>
16+
public static Nest<T> PackSequenceAs<T>(INestable<T> template, T[] flatItems)
17+
{
18+
return template.AsNest().PackSequence(flatItems);
19+
}
20+
21+
/// <summary>
22+
/// Pack the flat items to a nested sequence by the template.
23+
/// </summary>
24+
/// <typeparam name="T"></typeparam>
25+
/// <param name="template"></param>
26+
/// <param name="flatItems"></param>
27+
/// <returns></returns>
28+
public static Nest<T> PackSequenceAs<T>(INestable<T> template, List<T> flatItems)
29+
{
30+
return template.AsNest().PackSequence(flatItems.ToArray());
31+
}
32+
33+
/// <summary>
34+
/// Flatten the nested object.
35+
/// </summary>
36+
/// <typeparam name="T"></typeparam>
37+
/// <param name="nestedObject"></param>
38+
/// <returns></returns>
39+
public static IEnumerable<T> Flatten<T>(INestable<T> nestedObject)
40+
{
41+
return nestedObject.AsNest().Flatten();
42+
}
43+
44+
/// <summary>
45+
/// Map the structure with specified function.
46+
/// </summary>
47+
/// <typeparam name="TIn"></typeparam>
48+
/// <typeparam name="TOut"></typeparam>
49+
/// <param name="func"></param>
50+
/// <param name="nestedObject"></param>
51+
/// <returns></returns>
52+
public static INestStructure<TOut> MapStructure<TIn, TOut>(Func<TIn, TOut> func, INestable<TIn> nestedObject)
53+
{
54+
return nestedObject.AsNest().MapStructure(func);
55+
}
56+
57+
public static bool IsNested<T>(INestable<T> obj)
58+
{
59+
return obj.AsNest().IsNested();
60+
}
61+
}
62+
}

0 commit comments

Comments
 (0)