Skip to content

Commit 46e190d

Browse files
committed
feat: add RNN basic framework.
1 parent e9f2cac commit 46e190d

File tree

88 files changed

+1789
-188
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

88 files changed

+1789
-188
lines changed

src/TensorFlowNET.Core/Extensions/JObjectExtensions.cs renamed to src/TensorFlowNET.Core/Common/Extensions/JObjectExtensions.cs

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,16 @@
33
using System.Collections.Generic;
44
using System.Text;
55

6-
namespace Tensorflow.Extensions
6+
namespace Tensorflow.Common.Extensions
77
{
88
public static class JObjectExtensions
99
{
1010
public static T? TryGetOrReturnNull<T>(this JObject obj, string key)
1111
{
1212
var res = obj[key];
13-
if(res is null)
13+
if (res is null)
1414
{
15-
return default(T);
15+
return default;
1616
}
1717
else
1818
{
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Extensions
7+
{
8+
/// <summary>
/// LINQ-style helpers: back-ports of <c>TakeLast</c>/<c>SkipLast</c> for
/// NETSTANDARD2_0 builds, plus a convenience conversion to <c>Tensors</c>.
/// </summary>
public static class LinqExtensions
{
#if NETSTANDARD2_0
    /// <summary>
    /// Returns the last <paramref name="count"/> elements of <paramref name="sequence"/>.
    /// </summary>
    /// <param name="sequence">The source sequence.</param>
    /// <param name="count">Number of trailing elements to keep. Values &lt;= 0 yield an empty result.</param>
    /// <returns>The trailing elements, or the whole sequence if it is shorter than <paramref name="count"/>.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sequence"/> is null.</exception>
    public static IEnumerable<T> TakeLast<T>(this IEnumerable<T> sequence, int count)
    {
        if (sequence is null)
        {
            throw new ArgumentNullException(nameof(sequence));
        }
        // Guard first: avoids enumerating at all, and avoids `Count - count`
        // overflowing for very negative counts (which used to return everything).
        if (count <= 0)
        {
            return Enumerable.Empty<T>();
        }
        // Snapshot non-collection sources once so that single-pass or expensive
        // sequences are not enumerated twice (once for Count, once for Skip).
        var buffer = sequence as ICollection<T> ?? sequence.ToList();
        return buffer.Skip(Math.Max(0, buffer.Count - count));
    }

    /// <summary>
    /// Returns <paramref name="sequence"/> without its last <paramref name="count"/> elements.
    /// </summary>
    /// <param name="sequence">The source sequence.</param>
    /// <param name="count">Number of trailing elements to drop. Values &lt;= 0 drop nothing.</param>
    /// <returns>The leading elements, or an empty sequence if <paramref name="count"/> covers the whole source.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="sequence"/> is null.</exception>
    public static IEnumerable<T> SkipLast<T>(this IEnumerable<T> sequence, int count)
    {
        if (sequence is null)
        {
            throw new ArgumentNullException(nameof(sequence));
        }
        if (count <= 0)
        {
            // Nothing to drop; Skip(0) hides the original instance from the caller.
            return sequence.Skip(0);
        }
        // Same single-enumeration snapshot as TakeLast.
        var buffer = sequence as ICollection<T> ?? sequence.ToList();
        return buffer.Take(Math.Max(0, buffer.Count - count));
    }
#endif
    /// <summary>
    /// Wraps a sequence of tensors in a <c>Tensors</c> instance.
    /// </summary>
    /// <param name="tensors">The tensors to wrap; passed straight to the <c>Tensors</c> constructor.</param>
    public static Tensors ToTensors(this IEnumerable<Tensor> tensors)
    {
        return new Tensors(tensors);
    }
}
26+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Diagnostics;
4+
using System.Text;
5+
6+
namespace Tensorflow.Common.Types
7+
{
8+
/// <summary>
/// Holds one or more shape configurations and enumerates their dimension
/// arrays. A null dimension is reported as -1 when converting to
/// <c>Shape</c> or to a number.
/// </summary>
public class GeneralizedTensorShape: IEnumerable<long?[]>
{
    /// <summary>The underlying shape configurations.</summary>
    public TensorShapeConfig[] Shapes { get; set; }

    /// <summary>
    /// Creates a generalized shape holding a single shape with a single dimension.
    /// </summary>
    /// <param name="dim">The only dimension of the only shape.</param>
    public GeneralizedTensorShape(int dim)
    {
        var single = new TensorShapeConfig() { Items = new long?[] { dim } };
        Shapes = new TensorShapeConfig[] { single };
    }

    /// <summary>Creates a generalized shape wrapping one <c>Shape</c>.</summary>
    public GeneralizedTensorShape(Shape shape)
    {
        Shapes = new TensorShapeConfig[] { shape };
    }

    /// <summary>Creates a generalized shape wrapping one shape configuration.</summary>
    public GeneralizedTensorShape(TensorShapeConfig shape)
    {
        Shapes = new TensorShapeConfig[] { shape };
    }

    /// <summary>Creates a generalized shape that uses the given array directly (no copy).</summary>
    public GeneralizedTensorShape(TensorShapeConfig[] shapes)
    {
        Shapes = shapes;
    }

    /// <summary>Creates a generalized shape from several <c>Shape</c> values.</summary>
    public GeneralizedTensorShape(IEnumerable<Shape> shape)
    {
        Shapes = shape.Select(item => (TensorShapeConfig)item).ToArray();
    }

    /// <summary>
    /// Converts to a single <c>Shape</c>; only valid when exactly one shape is held.
    /// </summary>
    /// <exception cref="ValueError">The number of held shapes is not exactly one.</exception>
    public Shape ToSingleShape()
    {
        if (Shapes.Length == 1)
        {
            var only = Shapes[0];
            Debug.Assert(only is not null);
            return new Shape(only.Items.Select(d => d ?? -1).ToArray());
        }
        throw new ValueError("The generalized shape contains more than 1 dim.");
    }

    /// <summary>
    /// Converts to a single number; only valid when exactly one shape with
    /// exactly one dimension is held.
    /// </summary>
    /// <returns>The dimension, or -1 when it is null.</returns>
    /// <exception cref="ValueError">There is not exactly one shape with exactly one dimension.</exception>
    public long ToNumber()
    {
        bool isSingleDim = Shapes.Length == 1 && Shapes[0].Items.Length == 1;
        if (!isSingleDim)
        {
            throw new ValueError("The generalized shape contains more than 1 dim.");
        }
        return Shapes[0].Items[0] ?? -1;
    }

    /// <summary>Converts every held shape configuration to a <c>Shape</c>.</summary>
    public Shape[] ToShapeArray()
    {
        var converted = new Shape[Shapes.Length];
        for (int i = 0; i < converted.Length; i++)
        {
            converted[i] = new Shape(Shapes[i].Items.Select(d => d ?? -1).ToArray());
        }
        return converted;
    }

    /// <summary>Enumerates the dimension array of each held shape, in order.</summary>
    public IEnumerator<long?[]> GetEnumerator()
    {
        // Capture the array once, as the original foreach did.
        var all = Shapes;
        for (int i = 0; i < all.Length; i++)
        {
            yield return all[i].Items;
        }
    }

    IEnumerator IEnumerable.GetEnumerator()
    {
        return GetEnumerator();
    }
}
79+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Common.Types
6+
{
7+
/// <summary>
8+
/// Marker interface for the extra arguments that some corresponding Python methods accept.
9+
/// For example, `Keras.Layer.Apply` generally takes three args as its inputs, while
10+
/// `Keras.Layer.RNN` takes more. When calling an RNN, pass an `RnnOptionalArgs` instance
11+
/// as the optional-args parameter of the method.
12+
/// </summary>
13+
public interface IOptionalArgs
14+
{
15+
/// <summary>
16+
/// The identifier of the implementing class. It is not an argument itself; it only
17+
/// serves to tell different OptionalArgs implementations apart.
18+
/// </summary>
19+
string Identifier { get; }
20+
}
21+
}

src/TensorFlowNET.Core/Keras/Saving/TensorShapeConfig.cs renamed to src/TensorFlowNET.Core/Common/Types/TensorShapeConfig.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
using System.Collections.Generic;
44
using System.Linq;
55

6-
namespace Tensorflow.Keras.Saving
6+
namespace Tensorflow.Common.Types
77
{
88
public class TensorShapeConfig
99
{

src/TensorFlowNET.Core/Keras/ArgsDefinition/Rnn/RNNArgs.cs

+6-5
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,15 @@
11
using Newtonsoft.Json;
22
using System.Collections.Generic;
3+
using Tensorflow.Keras.Layers.Rnn;
34

45
namespace Tensorflow.Keras.ArgsDefinition.Rnn
56
{
7+
// TODO(Rinne): add regularizers.
68
public class RNNArgs : AutoSerializeLayerArgs
79
{
8-
public interface IRnnArgCell : ILayer
9-
{
10-
object state_size { get; }
11-
}
1210
[JsonProperty("cell")]
1311
// TODO: the cell should be serialized with `serialize_keras_object`.
14-
public IRnnArgCell Cell { get; set; } = null;
12+
public IRnnCell Cell { get; set; } = null;
1513
[JsonProperty("return_sequences")]
1614
public bool ReturnSequences { get; set; } = false;
1715
[JsonProperty("return_state")]
@@ -34,6 +32,9 @@ public interface IRnnArgCell : ILayer
3432
public IInitializer KernelInitializer { get; set; }
3533
public IInitializer RecurrentInitializer { get; set; }
3634
public IInitializer BiasInitializer { get; set; }
35+
public float Dropout { get; set; } = .0f;
36+
public bool ZeroOutputForMask { get; set; } = false;
37+
public float RecurrentDropout { get; set; } = .0f;
3738

3839
// kernel_regularizer=None,
3940
// recurrent_regularizer=None,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Common.Types;
5+
6+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
7+
{
8+
/// <summary>
/// Optional arguments for RNN layers, identified by the string "Rnn".
/// </summary>
public class RnnOptionalArgs: IOptionalArgs
9+
{
10+
/// <summary>Constant identifier for this kind of optional args.</summary>
public string Identifier => "Rnn";
11+
/// <summary>Optional mask tensor; null by default. NOTE(review): expected shape/dtype not visible here.</summary>
public Tensor Mask { get; set; } = null;
12+
/// <summary>Optional constant tensors; null by default. NOTE(review): how cells consume these is not visible here.</summary>
public Tensors Constants { get; set; } = null;
13+
}
14+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using Newtonsoft.Json;
2+
using System;
3+
using System.Collections.Generic;
4+
using System.Text;
5+
6+
namespace Tensorflow.Keras.ArgsDefinition.Rnn
7+
{
8+
/// <summary>
/// Arguments for a simple RNN cell. Each property is serialized to JSON under
/// the snake_case key given by its <c>JsonProperty</c> attribute.
/// </summary>
public class SimpleRNNCellArgs: AutoSerializeLayerArgs
9+
{
10+
[JsonProperty("units")]
11+
public int Units { get; set; }
12+
// TODO(Rinne): Activation has no default value here. Merging keras
13+
// into tf.net could resolve it.
14+
[JsonProperty("activation")]
15+
public Activation Activation { get; set; }
16+
// Defaults to true.
[JsonProperty("use_bias")]
17+
public bool UseBias { get; set; } = true;
18+
// Defaults to 0.
[JsonProperty("dropout")]
19+
public float Dropout { get; set; } = .0f;
20+
// Defaults to 0.
[JsonProperty("recurrent_dropout")]
21+
public float RecurrentDropout { get; set; } = .0f;
22+
// The three initializers below have no default (null unless set by the caller).
[JsonProperty("kernel_initializer")]
23+
public IInitializer KernelInitializer { get; set; }
24+
[JsonProperty("recurrent_initializer")]
25+
public IInitializer RecurrentInitializer { get; set; }
26+
[JsonProperty("bias_initializer")]
27+
public IInitializer BiasInitializer { get; set; }
28+
}
29+
}

src/TensorFlowNET.Core/Keras/Layers/ILayer.cs

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
using Tensorflow.Keras.Engine;
1+
using Tensorflow.Common.Types;
2+
using Tensorflow.Keras.Engine;
23
using Tensorflow.Keras.Saving;
34
using Tensorflow.NumPy;
45
using Tensorflow.Training;
@@ -14,7 +15,7 @@ public interface ILayer: IWithTrackable, IKerasConfigable
1415
List<ILayer> Layers { get; }
1516
List<INode> InboundNodes { get; }
1617
List<INode> OutboundNodes { get; }
17-
Tensors Apply(Tensors inputs, Tensor state = null, bool training = false);
18+
Tensors Apply(Tensors inputs, Tensors states = null, bool training = false, IOptionalArgs? optional_args = null);
1819
List<IVariableV1> TrainableVariables { get; }
1920
List<IVariableV1> TrainableWeights { get; }
2021
List<IVariableV1> NonTrainableWeights { get; }
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
using Tensorflow.Common.Types;
5+
6+
namespace Tensorflow.Keras.Layers.Rnn
7+
{
8+
/// <summary>
/// A layer usable as an RNN cell: in addition to <c>ILayer</c> it exposes
/// its state and output sizes and a <c>Call</c> that takes inputs and states.
/// </summary>
public interface IRnnCell: ILayer
9+
{
10+
/// <summary>Size of the cell state, as a generalized shape.</summary>
GeneralizedTensorShape StateSize { get; }
11+
/// <summary>Size of the cell output, as a generalized shape.</summary>
GeneralizedTensorShape OutputSize { get; }
12+
/// <summary>
13+
/// Whether the optional RNN args are supported when applying the layer.
14+
/// In other words, whether `Apply` is overridden to handle `RnnOptionalArgs`.
15+
/// </summary>
16+
bool SupportOptionalArgs { get; }
17+
/// <summary>
/// Invokes the cell with the given inputs and states.
/// NOTE(review): the returned pair is presumably (output, new states) — confirm against implementations.
/// </summary>
(Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null);
18+
}
19+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Text;
4+
5+
namespace Tensorflow.Keras.Layers.Rnn
6+
{
7+
/// <summary>
/// An RNN cell that also exposes a count and an indexer returning <c>IRnnCell</c> instances.
/// </summary>
public interface IStackedRnnCells : IRnnCell
8+
{
9+
/// <summary>Number of cells. NOTE(review): presumably the number of stacked cells — confirm against implementations.</summary>
int Count { get; }
10+
/// <summary>Gets the cell at index <paramref name="idx"/>.</summary>
IRnnCell this[int idx] { get; }
11+
}
12+
}

src/TensorFlowNET.Core/Keras/Saving/Json/CustomizedKerasShapesWrapperJsonConverter.cs

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
using System;
44
using System.Collections.Generic;
55
using System.Text;
6+
using Tensorflow.Common.Types;
67

78
namespace Tensorflow.Keras.Saving.Json
89
{

src/TensorFlowNET.Core/Keras/Saving/KerasShapesWrapper.cs

+1
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
using System.Diagnostics;
77
using OneOf.Types;
88
using Tensorflow.Keras.Saving.Json;
9+
using Tensorflow.Common.Types;
910

1011
namespace Tensorflow.Keras.Saving
1112
{

src/TensorFlowNET.Core/NumPy/Axis.cs

-5
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,3 @@ public override string ToString()
7474
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})";
7575
}
7676
}
77-
78-
namespace System.Runtime.CompilerServices
79-
{
80-
internal static class IsExternalInit { }
81-
}

src/TensorFlowNET.Core/Operations/Initializers/Orthogonal.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ private Tensor _generate_init_val(Shape shape, TF_DataType dtype)
5353
// Compute the qr factorization
5454
var (q, r) = tf.linalg.qr(a, full_matrices: false);
5555
// Make Q uniform
56-
var d = tf.linalg.tensor_diag_part(r);
56+
var d = tf.linalg.tensor_diag_part(r.Single);
5757
q *= tf.sign(d);
5858

5959
if (num_rows < num_cols)

src/TensorFlowNET.Core/Operations/NnOps/BasicLSTMCell.cs

+1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ namespace Tensorflow
1111
/// Basic LSTM recurrent network cell.
1212
/// The implementation is based on: http://arxiv.org/abs/1409.2329.
1313
/// </summary>
14+
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
1415
public class BasicLstmCell : LayerRnnCell
1516
{
1617
int _num_units;

src/TensorFlowNET.Core/Operations/NnOps/BasicRNNCell.cs

+1
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020

2121
namespace Tensorflow
2222
{
23+
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
2324
public class BasicRnnCell : LayerRnnCell
2425
{
2526
int _num_units;

src/TensorFlowNET.Core/Operations/NnOps/LayerRNNCell.cs

+1
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ limitations under the License.
1919

2020
namespace Tensorflow
2121
{
22+
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
2223
public class LayerRnnCell : RnnCell
2324
{
2425
protected InputSpec inputSpec;

src/TensorFlowNET.Core/Operations/NnOps/RNNCell.cs

+13-2
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,12 @@ limitations under the License.
1616

1717
using System;
1818
using System.Collections.Generic;
19+
using Tensorflow.Common.Types;
1920
using Tensorflow.Keras;
2021
using Tensorflow.Keras.ArgsDefinition;
2122
using Tensorflow.Keras.ArgsDefinition.Rnn;
2223
using Tensorflow.Keras.Engine;
24+
using Tensorflow.Keras.Layers.Rnn;
2325
using Tensorflow.Keras.Saving;
2426
using Tensorflow.NumPy;
2527
using Tensorflow.Operations;
@@ -50,7 +52,8 @@ namespace Tensorflow
5052
/// matching structure of Tensors having shape `[batch_size].concatenate(s)`
5153
/// for each `s` in `self.batch_size`.
5254
/// </summary>
53-
public abstract class RnnCell : ILayer, RNNArgs.IRnnArgCell
55+
[Obsolete("This is an incompleted tf v1 api, pleas use keras RNNs instead.")]
56+
public abstract class RnnCell : ILayer, IRnnCell
5457
{
5558
/// <summary>
5659
/// Attribute that indicates whether the cell is a TF RNN cell, due the slight
@@ -142,7 +145,7 @@ private Tensor _zero_state_tensors(object state_size, Tensor batch_size, TF_Data
142145
throw new NotImplementedException("_zero_state_tensors");
143146
}
144147

145-
public Tensors Apply(Tensors inputs, Tensor state = null, bool is_training = false)
148+
public Tensors Apply(Tensors inputs, Tensors state = null, bool is_training = false, IOptionalArgs? optional_args = null)
146149
{
147150
throw new NotImplementedException();
148151
}
@@ -173,5 +176,13 @@ public void adapt(Tensor data, int? batch_size = null, int? steps = null)
173176
{
174177
throw new NotImplementedException();
175178
}
179+
180+
public (Tensor, Tensors) Call(Tensors inputs, Tensors states, bool? training = null)
181+
{
182+
throw new NotImplementedException();
183+
}
184+
public GeneralizedTensorShape StateSize => throw new NotImplementedException();
185+
public GeneralizedTensorShape OutputSize => throw new NotImplementedException();
186+
public bool SupportOptionalArgs => throw new NotImplementedException();
176187
}
177188
}

src/TensorFlowNET.Core/Operations/logging_ops.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ public Tensor print_v2(Tensor input, string output_stream = "stderr", string end
3030
name: name);
3131

3232
return tf.Context.ExecuteOp("PrintV2", name, new ExecuteOpArgs(formatted_string)
33-
.SetAttributes(new { output_stream, end }));
33+
.SetAttributes(new { output_stream, end })).SingleOrNull;
3434
}
3535
}
3636
}

0 commit comments

Comments
 (0)