From 0360fbb5304813735d3fe36f7c46b887be139900 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Thu, 13 Apr 2023 02:54:05 +0800 Subject: [PATCH] Fix the error when saving model with GPU. --- .../Checkpoint/SaveUtilV1.cs | 13 +- .../Checkpoint/checkpoint.cs | 20 +- .../Checkpoint/functional_saver.cs | 129 ++++++----- .../Contexts/Context.Device.cs | 73 +++++++ src/TensorFlowNET.Core/Contexts/Context.cs | 5 +- .../Contexts/EagerDeviceContext.cs | 71 ++++++ src/TensorFlowNET.Core/Device/DeviceSpec.cs | 205 ++++++++++++++++++ src/TensorFlowNET.Core/Device/DeviceUtils.cs | 26 +++ src/TensorFlowNET.Core/Graphs/Graph.cs | 11 +- .../Graphs/GraphDeviceContext.cs | 31 +++ .../Saving/ResourceVariableSaveable.cs | 20 +- .../Training/Saving/SavedModel/loader.cs | 24 +- .../Variables/BaseResourceVariable.cs | 6 +- .../Variables/UninitializedVariable.cs | 9 +- src/TensorFlowNET.Core/ops.cs | 17 ++ 15 files changed, 568 insertions(+), 92 deletions(-) create mode 100644 src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs create mode 100644 src/TensorFlowNET.Core/Device/DeviceSpec.cs create mode 100644 src/TensorFlowNET.Core/Device/DeviceUtils.cs create mode 100644 src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs diff --git a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs index 72372e410..5cda317bf 100644 --- a/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs +++ b/src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs @@ -53,8 +53,11 @@ public static (IList, IDictionary + { + // TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception. + return constant_op.constant(graph_proto.ToByteArray()); + }); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); g.Exit(); return (named_saveable_objects, registered_savers); @@ -65,8 +68,10 @@ public static (IList, IDictionary + { + return constant_op.constant(graph_proto.ToString()); + }); named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY)); return (named_saveable_objects, registered_savers); } diff --git a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs index 1934ffd5f..53b13d203 100644 --- a/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs +++ b/src/TensorFlowNET.Core/Checkpoint/checkpoint.cs @@ -58,8 +58,10 @@ public TrackableSaver(ObjectGraphView graph_view) if(object_graph_tensor is null) { - tf.device("/cpu:0"); - object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); + tf_with(ops.device("/cpu:0"), _ => + { + object_graph_tensor = constant_op.constant(graph_proto.ToByteArray()); + }); } else { @@ -230,13 +232,15 @@ public LoadStatus restore(string? save_path, CheckpointOptions? options = null) Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING); Dictionary file_prefix_feed_dict; - Tensor file_prefix_tensor; + Tensor file_prefix_tensor = null; if (graph_building) { if(_file_prefix_placeholder is null) { - tf.device("/cpu:0"); - _file_prefix_placeholder = constant_op.constant("model"); + _file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant("model"); + }); } file_prefix_tensor = _file_prefix_placeholder; file_prefix_feed_dict = new(); @@ -244,8 +248,10 @@ public LoadStatus restore(string? save_path, CheckpointOptions? options = null) } else { - tf.device("/cpu:0"); - file_prefix_tensor = constant_op.constant(save_path); + file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ => + { + return constant_op.constant(save_path); + }); file_prefix_feed_dict = null; } TrackableObjectGraph object_graph_proto = new(); diff --git a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs index 96e6c8dd9..05d947497 100644 --- a/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs +++ b/src/TensorFlowNET.Core/Checkpoint/functional_saver.cs @@ -211,9 +211,11 @@ public IDictionary> restore(Tensor file_pref string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!; - // tf python has code `with ops.device(restore_device):` here. - tf.device(restore_device); // may be risky. - var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + Tensor[] restored_tensors = null; + tf_with(ops.device(restore_device), _ => + { + restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray()); + }); Dictionary> restored_tensor_dict = new(); int idx = 0; @@ -338,11 +340,14 @@ public Operation save(Tensor file_prefix, CheckpointOptions? options= null) options = new CheckpointOptions(); } - tf.device("CPU"); // may be risky. - var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), + Tensor tmp_checkpoint_prefix = null; + tf_with(ops.device("CPU"), _ => + { + var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")), constant_op.constant(".part"), constant_op.constant("_temp/part")); - var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); - IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix }); + IDictionary registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x)); + }); Operation save_fn() { @@ -364,16 +369,24 @@ Operation save_fn() var saver = pair.Value; last_device = device; // skip the extra process of device name because of lack of API. - tf.device(device); - var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + Tensor shard_prefix = null; + tf_with(ops.device(device), _ => + { + shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor); + }); saved_prefixes.Add(shard_prefix); - sharded_saves.Add(saver.save(shard_prefix, options)); + tf_with(ops.device(device), _ => + { + sharded_saves.Add(saver.save(shard_prefix, options)); + }); } using (var controller = ops.control_dependencies(sharded_saves.ToArray())) { string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device; - tf.device(merge_device); - return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); + return tf_with(ops.device(merge_device), _ => + { + return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true); + }); } } @@ -407,54 +420,56 @@ IDictionary restore_func() { var device = single_saver.Key; var saver = single_saver.Value; - tf.device(device); - var restored_tensor_dict = saver.restore(file_prefix, options); - - foreach(var pair in restored_tensor_dict) + tf_with(ops.device(device), _ => { - var checkpoint_key = pair.Key; - var slice_and_tensor = pair.Value; - foreach(var item in slice_and_tensor) + var restored_tensor_dict = saver.restore(file_prefix, options); + + foreach (var pair in restored_tensor_dict) { - var slice_spec = item.Key; - var tensor = item.Value; - var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; - var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); - if (!string.IsNullOrEmpty(slice_spec)) + var checkpoint_key = pair.Key; + var slice_and_tensor = pair.Value; + foreach (var item in slice_and_tensor) { - if (!internal_dict.ContainsKey(checkpoint_key)) + var slice_spec = item.Key; + var tensor = item.Value; + var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)]; + var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary>>()); + if (!string.IsNullOrEmpty(slice_spec)) { - Dictionary dict = new(); - dict[slice_spec] = tensor; - internal_dict[checkpoint_key] = new Maybe>(dict); + if (!internal_dict.ContainsKey(checkpoint_key)) + { + Dictionary dict = new(); + dict[slice_spec] = tensor; + internal_dict[checkpoint_key] = new Maybe>(dict); + } + else + { + internal_dict[checkpoint_key].GetValue>()[slice_spec] = tensor; + } } else { - internal_dict[checkpoint_key].GetValue>()[slice_spec] = tensor; + internal_dict[checkpoint_key] = new Maybe>(tensor); } - } - else - { - internal_dict[checkpoint_key] = new Maybe>(tensor); - } - restore_fn_input_count[restore_fn]--; + restore_fn_input_count[restore_fn]--; - if (restore_fn_input_count[restore_fn] == 0) - { - Dictionary>> restored_tensors = new(); - foreach(var input in restore_fn_inputs[restore_fn]) + if (restore_fn_input_count[restore_fn] == 0) { - restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; - } - var ret = restore_fn.DynamicInvoke(restored_tensors); - if(ret is IDictionary) - { - var dict = (IDictionary)ret; - restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + Dictionary>> restored_tensors = new(); + foreach (var input in restore_fn_inputs[restore_fn]) + { + restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value; + } + var ret = restore_fn.DynamicInvoke(restored_tensors); + if (ret is IDictionary) + { + var dict = (IDictionary)ret; + restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value); + } } } } - } + }); } foreach(var item in _registered_savers) @@ -500,21 +515,25 @@ public SaverDef to_proto() private Tensor _traced_save(Tensor file_prefix) { var save_op = save(file_prefix); - tf.device("cpu:0"); - using (ops.control_dependencies(new object[]{ save_op })) + return tf_with(ops.device("cpu:0"), _ => { - return array_ops.identity(file_prefix); - } + return tf_with(ops.control_dependencies(new object[] { save_op }), __ => + { + return array_ops.identity(file_prefix); + }); + }); } private Tensor _traced_restore(Tensor file_prefix) { var restore_op = restore(file_prefix); - tf.device("cpu:0"); - using (ops.control_dependencies(restore_op.Values.ToArray())) + return tf_with(ops.device("cpu:0"), _ => { - return array_ops.identity(file_prefix); - } + return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ => + { + return array_ops.identity(file_prefix); + }); + }); } public static MultiDeviceSaver from_saveables(IEnumerable saveables, IDictionary>? registered_savers = null, bool call_with_mapped_captures = false) diff --git a/src/TensorFlowNET.Core/Contexts/Context.Device.cs b/src/TensorFlowNET.Core/Contexts/Context.Device.cs index 97c550e8e..32e6682e0 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.Device.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.Device.cs @@ -21,6 +21,7 @@ limitations under the License. using static Tensorflow.Binding; using Google.Protobuf; using Tensorflow.Device; +using Tensorflow.Exceptions; using System.Collections.Generic; namespace Tensorflow.Contexts @@ -30,10 +31,30 @@ namespace Tensorflow.Contexts /// public sealed partial class Context { + internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new(); + internal List _logical_devices = null; + internal List _context_devices = null; + ContextDevicePlacementPolicy _device_policy; bool _log_device_placement; + int _num_gpus; Dictionary _memory_growth_map = new Dictionary(); + public string DeviceName { get; set; } = ""; + public DeviceSpec DeviceSpec { get; set; } = null; + + internal List Devices + { + get + { + if(_context_devices is null) + { + throw new AssertionError("Context must be initialized first."); + } + return _context_devices; + } + } + public void log_device_placement(bool enable) { if (_handle != null) @@ -89,5 +110,57 @@ public PhysicalDevice[] list_physical_devices(string device_type = null) return results.ToArray(); } + + public EagerDeviceContext device(string name) + { + return new EagerDeviceContext(this, name); + } + + internal void _set_device(string device_name, DeviceSpec device_spec) + { + DeviceSpec = device_spec; + DeviceName = device_name; + } + + internal void _initialize_logical_devices() + { + List logical_devices = new(); + List context_devices = new(); + Status status = new(); + var device_list = c_api.TFE_ContextListDevices(_handle, status); + status.Check(true); + try + { + this._num_gpus = 0; + string current_job = null; + int current_task = -1; + for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++) + { + var dev_name = c_api.TF_DeviceListName(device_list, i, status); + status.Check(true); + context_devices.Add(DeviceUtils.canonical_name(dev_name)); + var spec = DeviceSpec.from_string(dev_name); + if(spec.Job == "localhost") + { + spec = spec.replace(job: null, replica: -1, task: -1); + } + logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType)); + var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status); + var dev_type = c_api.StringPiece(dev_type_memory); + status.Check(true); + if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task) + { + _num_gpus++; + } + } + } + finally + { + _logical_devices = logical_devices; + _context_devices = context_devices; + } + } } + + public record class LogicalDevice(string name, string device_type); } diff --git a/src/TensorFlowNET.Core/Contexts/Context.cs b/src/TensorFlowNET.Core/Contexts/Context.cs index 21a14831f..742e9ddf3 100644 --- a/src/TensorFlowNET.Core/Contexts/Context.cs +++ b/src/TensorFlowNET.Core/Contexts/Context.cs @@ -34,7 +34,6 @@ public sealed partial class Context public const int EAGER_MODE = 1; int defaultExecutionMode = EAGER_MODE; - public string DeviceName { get; set; } = ""; public string ScopeName { get; set; } = ""; bool initialized = false; ContextSwitchStack context_switches; @@ -62,6 +61,8 @@ public void ensure_initialized() if (initialized) return; + Debug.Assert(_context_devices is null); + Config = MergeConfig(); FunctionCallOptions.Config = Config; var config_str = Config.ToByteArray(); @@ -72,6 +73,7 @@ public void ensure_initialized() c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy); _handle = c_api.TFE_NewContext(opts, status); status.Check(true); + _initialize_logical_devices(); initialized = true; } @@ -174,6 +176,7 @@ public void reset_context() { c_api.TFE_ContextClearCaches(_handle); } + _device_parsing_cache.Clear(); } public static implicit operator SafeContextHandle(Context ctx) diff --git a/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs b/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs new file mode 100644 index 000000000..2d5f61cdb --- /dev/null +++ b/src/TensorFlowNET.Core/Contexts/EagerDeviceContext.cs @@ -0,0 +1,71 @@ +using System; +using System.Collections.Generic; +using System.Text; +using Tensorflow.Device; + +namespace Tensorflow.Contexts +{ + public class EagerDeviceContext : ITensorFlowObject + { + private Context _ctx; + private string _device_name; + private Stack<(string, DeviceSpec, DeviceSpec)> _stack; + + public EagerDeviceContext(Context ctx, string device_name) + { + _ctx = ctx; + _device_name = device_name; + _stack = new Stack<(string, DeviceSpec, DeviceSpec)>(); + } + public void __enter__() + { + var ctx = _ctx; + var old_device_name = ctx.DeviceName; + var old_device_spec = ctx.DeviceSpec; + var new_device_name = _device_name; + var cache_key = (old_device_name, new_device_name); + DeviceSpec new_device_spec; + if (Context._device_parsing_cache.ContainsKey(cache_key)) + { + (new_device_name, new_device_spec) = Context._device_parsing_cache[cache_key]; + } + else + { + if(new_device_name is not null) + { + var device_spec = DeviceSpec.from_string(new_device_name); + if (!string.IsNullOrEmpty(old_device_name)) + { + new_device_spec = new DeviceSpec(old_device_spec); + } + else + { + ctx.ensure_initialized(); + new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); + } + new_device_spec = new_device_spec.make_merged_spec(device_spec); + } + else + { + new_device_spec = DeviceSpec.from_string(ctx._context_devices[0]); + } + new_device_name = new_device_spec.ToString(); + Context._device_parsing_cache[cache_key] = (new_device_name, new_device_spec); + } + ctx._set_device(new_device_name, new_device_spec); + _stack.Push((old_device_name, old_device_spec, new_device_spec)); + } + + public void __exit__() + { + var ctx = _ctx; + var (old_device_name, old_device_spec, new_device_spec) = _stack.Pop(); + ctx._set_device(old_device_name, old_device_spec); + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Device/DeviceSpec.cs b/src/TensorFlowNET.Core/Device/DeviceSpec.cs new file mode 100644 index 000000000..f4ea8cf05 --- /dev/null +++ b/src/TensorFlowNET.Core/Device/DeviceSpec.cs @@ -0,0 +1,205 @@ +using System; +using System.Collections.Generic; +using System.Text; +using System.Threading.Tasks; + +namespace Tensorflow.Device +{ + public class DeviceSpec + { + private static Dictionary _STRING_TO_COMPONENTS_CACHE = new(); + private static Dictionary _COMPONENTS_TO_STRING_CACHE = new(); + private string _job; + private int _replica; + private int _task; + private string _device_type; + private int _device_index; + private string _as_string; + + public string Job => _job; + public int Replica => _replica; + public int Task => _task; + public string DeviceType => _device_type; + public int DeviceIndex => _device_index; + + public DeviceSpec(string job = null, int replica = -1, int task = -1, + string device_type = null, int device_index = -1) + { + _job = job; + _replica = replica; + _task = task; + _device_type = device_type; + _device_index = device_index; + _as_string = _components_to_string(job, replica, task, device_type, _device_index); + + } + + public DeviceSpec(DeviceSpec other) + { + _job = other._job; + _replica = other._replica; + _task = other._task; + _device_type = other._device_type; + _device_index = other._device_index; + _as_string = other._as_string; + } + + protected DeviceSpec(Components com) + { + _job = com.Job; + _replica = com.Replica; + _task = com.Task; + _device_type = com.DeviceType; + _device_index = com.DeviceIndex; + _as_string = _components_to_string(_job, _replica, _task, _device_type, _device_index); + } + + public DeviceSpec replace(string job = null, int replica = -1, int task = -1, + string device_type = null, int device_index = -1) + { + job = job ?? _job; + replica = replica == -1 ? _replica : replica; + task = task == -1 ? _task : task; + device_type = device_type ?? _device_type; + device_index = device_index == -1 ? _device_index : device_index; + return new DeviceSpec(job, replica, task, device_type, device_index); + } + + public static DeviceSpec from_string(string spec) + { + var components = _string_to_components(spec); + return new DeviceSpec(components.Job, components.Replica, components.Task, components.DeviceType, components.DeviceIndex); + } + + public DeviceSpec make_merged_spec(DeviceSpec dev) + { + return new DeviceSpec(_get_combined_properties(dev)); + } + + private Components _get_combined_properties(DeviceSpec dev) + { + return new Components( + dev.Job ?? _job, + dev.Replica == -1 ? _replica : dev.Replica, + dev.Task == -1 ? _task : dev.Task, + dev.DeviceType ?? _device_type, + dev.DeviceIndex == -1 ? _device_index : dev.DeviceIndex + ); + } + + private static string _components_to_string(string job, int replica, int task, string device_type, int device_index) + { + var key = new Components(job, replica, task, device_type, device_index); + if(_COMPONENTS_TO_STRING_CACHE.TryGetValue(key, out var cache_result)) + { + return cache_result; + } + + StringBuilder output = new(); + if(job is not null) + { + output.Append($"/job:{job}"); + } + if(replica != -1) + { + output.Append($"/replica:{replica}"); + } + if(task != -1) + { + output.Append($"/task:{task}"); + } + if (device_type is not null) + { + string device_index_string = "*"; + if (device_index != -1) + { + device_index_string = device_index.ToString(); + } + output.Append($"/device:{device_type}:{device_index_string}"); + } + var result = output.ToString(); + _COMPONENTS_TO_STRING_CACHE[key] = result; + return result; + } + + private static Components _string_to_components(string spec) + { + if(_STRING_TO_COMPONENTS_CACHE.TryGetValue(spec, out var cached_result)) + { + return cached_result; + } + var raw_spec = spec; + var splits = spec.Split('/').Select(x => x.Split(':')); + var valid_device_types = _get_valid_device_types(); + string job = null, device_type = null; + int replica = -1, task = -1, device_index = -1; + foreach (var y in splits) + { + var ly = y.Length; + if (ly > 0) + { + if(ly == 2 && y[0] == "job") + { + job = y[1]; + } + else if(ly == 2 && y[0] == "replica") + { + replica = int.Parse(y[1]); + } + else if(ly == 2 && y[0] == "task") + { + task = int.Parse(y[1]); + } + else if((ly == 1 || ly == 2) && valid_device_types.Contains(y[0].ToUpper())) + { + if (device_type is not null) + { + throw new ValueError($"Multiple device types are not allowed " + + $"while parsing the device spec: {spec}."); + } + device_type = y[0].ToUpper(); + if(ly == 2 && y[1] != "*") + { + device_index = int.Parse(y[1]); + } + } + else if(ly == 3 && y[0] == "device") + { + if(device_type is not null) + { + throw new ValueError($"Multiple device types are not allowed " + + $"while parsing the device spec: {spec}."); + } + device_type = y[1]; + if (y[2] != "*") + { + device_index = int.Parse(y[2]); + } + } + else if (y[0] != "") + { + throw new ValueError($"Unknown attribute '{y[0]}' is encountered " + + $"while parsing the device spec: {spec}."); + } + } + } + + var output = new Components(job, replica, task, device_type, device_index); + _STRING_TO_COMPONENTS_CACHE[raw_spec] = output; + return output; + } + + private static HashSet _get_valid_device_types() + { + // TODO(Rinne): revise it to calling C API (need customized API). + return new HashSet(new string[] { "CPU", "GPU" }); + } + + public override string ToString() + { + return _as_string; + } + + protected record class Components(string Job, int Replica, int Task, string DeviceType, int DeviceIndex); + } +} diff --git a/src/TensorFlowNET.Core/Device/DeviceUtils.cs b/src/TensorFlowNET.Core/Device/DeviceUtils.cs new file mode 100644 index 000000000..8f11e6c8a --- /dev/null +++ b/src/TensorFlowNET.Core/Device/DeviceUtils.cs @@ -0,0 +1,26 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Device +{ + internal static class DeviceUtils + { + public static string canonical_name(string device) + { + if(device is null) + { + return ""; + } + return DeviceSpec.from_string(device).ToString(); + } + public static string canonical_name(DeviceSpec device) + { + if (device is null) + { + return ""; + } + return device.ToString(); + } + } +} diff --git a/src/TensorFlowNET.Core/Graphs/Graph.cs b/src/TensorFlowNET.Core/Graphs/Graph.cs index 98cad3b28..0c49efd7e 100644 --- a/src/TensorFlowNET.Core/Graphs/Graph.cs +++ b/src/TensorFlowNET.Core/Graphs/Graph.cs @@ -19,6 +19,7 @@ limitations under the License. using System.Collections.Generic; using System.Collections.Specialized; using System.Linq; +using Tensorflow.Graphs; using static Tensorflow.Binding; namespace Tensorflow @@ -294,9 +295,15 @@ public virtual Operation create_op(string op_type, Tensor[] inputs, TF_DataType[ return op; } - public void device(string device_name) + public ITensorFlowObject device(string device_name) { - + return new GraphDeviceContext(this, device_name); + } + + private void add_device_to_stack(string device_name, int offset = 0) + { + // TODO(Rinne): deal with device spec. + int total_offset = offset + 1; } private void _create_op_helper(Operation op, bool compute_device = true) diff --git a/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs b/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs new file mode 100644 index 000000000..2754c2b36 --- /dev/null +++ b/src/TensorFlowNET.Core/Graphs/GraphDeviceContext.cs @@ -0,0 +1,31 @@ +using System; +using System.Collections.Generic; +using System.Text; + +namespace Tensorflow.Graphs +{ + public class GraphDeviceContext : ITensorFlowObject + { + private Graph _graph; + + public GraphDeviceContext(Graph graph, string device_name) + { + _graph = graph; + } + + public void __enter__() + { + + } + + public void __exit__() + { + + } + + public void Dispose() + { + + } + } +} diff --git a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs index 2d23a325f..35d982cd3 100644 --- a/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs +++ b/src/TensorFlowNET.Core/Training/Saving/ResourceVariableSaveable.cs @@ -42,16 +42,20 @@ public ResourceVariableSaveable(BaseResourceVariable var, string slice_spec, str _var_device = var.Device; _var_shape = var.shape; - Tensor _read_variable_closure(BaseResourceVariable v) + Tensor? _read_variable_closure(BaseResourceVariable v) { - tf.device(v.Device); - if(tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + return tf_with(ops.device(v.Device), _ => { - return null; - } - var x = v.read_value_no_copy(); - tf.device("/device:CPU:0"); - return array_ops.identity(x); + if (tf.Context.executing_eagerly() && !((bool)v.is_initialized().numpy())) + { + return null; + } + var x = v.read_value_no_copy(); + return tf_with(ops.device("/device:CPU:0"), __ => + { + return array_ops.identity(x); + }); + }); } this.handle_op = var.Handle; diff --git a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs index dc9e5ba56..65f5a01bf 100644 --- a/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs +++ b/src/TensorFlowNET.Core/Training/Saving/SavedModel/loader.cs @@ -412,8 +412,10 @@ private void _restore_checkpoint() { var variables_path = SavedModelUtils.get_variables_path(_export_dir); var saver = new TrackableSaver(new ObjectGraphView(get(0))); - tf.device("CPU"); - saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + tf_with(ops.device("CPU"), _ => + { + saver.FilePrefixPlaceHolder = constant_op.constant(variables_path); + }); LoadStatus load_status; if (_save_options.allow_partial_checkpoint) { @@ -598,14 +600,16 @@ private void _add_object_graph_edges(SavedObject proto, int node_id) if (load_with_device) { - tf.device(saved_device); - return (new UninitializedVariable( - shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), - dtype: (TF_DataType)proto.Dtype, - name: name, - trainable: trainable, - aggregation: aggregation - ), setattr); + return tf_with(ops.device(saved_device), _ => + { + return (new UninitializedVariable( + shape: new Shape(proto.Shape.Dim.Select(x => (int)x.Size).ToArray()), + dtype: (TF_DataType)proto.Dtype, + name: name, + trainable: trainable, + aggregation: aggregation + ), setattr); + }); } else { diff --git a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs index 9427b87ff..cd972adad 100644 --- a/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs +++ b/src/TensorFlowNET.Core/Variables/BaseResourceVariable.cs @@ -266,9 +266,11 @@ public override (IDictionary, IDictionary) BaseResourceVariable new_variable; if (save_options.experimental_variable_policy.save_variable_devices()) { - tf.device(this.Device); Debug.Assert(this is ResourceVariable); - new_variable = resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + new_variable = tf_with(ops.device(this.Device), _ => + { + return resource_variable_ops.copy_to_graph_uninitialized((ResourceVariable)this); + }); } else { diff --git a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs index 6c0349950..c12f84505 100644 --- a/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs +++ b/src/TensorFlowNET.Core/Variables/UninitializedVariable.cs @@ -49,9 +49,12 @@ public UninitializedVariable( { tf_with(ops.name_scope("Read"), _ => { - tf.device(handle.Device); - var value = gen_resource_variable_ops.read_variable_op(handle, dtype); - // _maybe_set_handle_data(dtype, handle, value) + var value = tf_with(ops.device(handle.Device), _ => + { + var result = gen_resource_variable_ops.read_variable_op(handle, dtype); + // TODO(Rinne): _maybe_set_handle_data(dtype, handle, value) + return result; + }); _graph_element = value; }); ops.add_to_collection(ops.GraphKeys.GLOBAL_VARIABLES_, this); diff --git a/src/TensorFlowNET.Core/ops.cs b/src/TensorFlowNET.Core/ops.cs index 48d8b5c5f..1f83f9ee5 100644 --- a/src/TensorFlowNET.Core/ops.cs +++ b/src/TensorFlowNET.Core/ops.cs @@ -577,6 +577,23 @@ public static void dismantle_graph(Graph graph) } + public static ITensorFlowObject device(string device_name) + { + if (tf.Context.executing_eagerly()) + { + return tf.Context.device(device_name); + } + //else if (ops.executing_eagerly_outside_functions()) + //{ + // throw new NotImplementedException(); + //} + else + { + return get_default_graph().device(device_name); + } + // TODO(Rinne): deal with `ops.executing_eagerly_outside_functions()`. + } + public class NullContextManager: IDisposable { public void Dispose()