Skip to content

Fix the error when saving model with GPU. #1023

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/TensorFlowNET.Core/Checkpoint/SaveUtilV1.cs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,11 @@ public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string,
var g = to_graph.as_default();
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
{
// TODO(Rinne): locate the error that causes transferring TF_STRING to this function throws an exception.
return constant_op.constant(graph_proto.ToByteArray());
});
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
g.Exit();
return (named_saveable_objects, registered_savers);
Expand All @@ -65,8 +68,10 @@ public static (IList<MySaveableObject>, IDictionary<string, IDictionary<string,
{
var (named_saveable_objects, graph_proto, _, registered_savers) = serialize_gathered_objects(graph_view,
object_map, call_with_mapped_captures, saveables_cache);
tf.device("/cpu:0");
var object_graph_tensor = constant_op.constant(graph_proto.ToString(), TF_DataType.TF_STRING);
var object_graph_tensor = tf_with(ops.device("/cpu:0"), _ =>
{
return constant_op.constant(graph_proto.ToString());
});
named_saveable_objects.Add(new NoRestoreSaveable(object_graph_tensor, Trackable.Constants.OBJECT_GRAPH_PROTO_KEY));
return (named_saveable_objects, registered_savers);
}
Expand Down
20 changes: 13 additions & 7 deletions src/TensorFlowNET.Core/Checkpoint/checkpoint.cs
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,10 @@ public TrackableSaver(ObjectGraphView graph_view)

if(object_graph_tensor is null)
{
tf.device("/cpu:0");
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
tf_with(ops.device("/cpu:0"), _ =>
{
object_graph_tensor = constant_op.constant(graph_proto.ToByteArray());
});
}
else
{
Expand Down Expand Up @@ -230,22 +232,26 @@ public LoadStatus restore(string? save_path, CheckpointOptions? options = null)
Tensor object_graph_string = reader.GetTensor(Trackable.Constants.OBJECT_GRAPH_PROTO_KEY, dtype: TF_DataType.TF_STRING);

Dictionary<Tensor, string> file_prefix_feed_dict;
Tensor file_prefix_tensor;
Tensor file_prefix_tensor = null;
if (graph_building)
{
if(_file_prefix_placeholder is null)
{
tf.device("/cpu:0");
_file_prefix_placeholder = constant_op.constant("model");
_file_prefix_placeholder = tf_with(ops.device("/cpu:0"), _ =>
{
return constant_op.constant("model");
});
}
file_prefix_tensor = _file_prefix_placeholder;
file_prefix_feed_dict = new();
file_prefix_feed_dict[_file_prefix_placeholder] = save_path;
}
else
{
tf.device("/cpu:0");
file_prefix_tensor = constant_op.constant(save_path);
file_prefix_tensor = tf_with(ops.device("/cpu:0"), _ =>
{
return constant_op.constant(save_path);
});
file_prefix_feed_dict = null;
}
TrackableObjectGraph object_graph_proto = new();
Expand Down
129 changes: 74 additions & 55 deletions src/TensorFlowNET.Core/Checkpoint/functional_saver.cs
Original file line number Diff line number Diff line change
Expand Up @@ -211,9 +211,11 @@ public IDictionary<string, IDictionary<string, Tensor>> restore(Tensor file_pref

string restore_device = string.IsNullOrEmpty(options.experimental_io_device) ? "cpu:0": options.experimental_io_device!;

// tf python has code `with ops.device(restore_device):` here.
tf.device(restore_device); // may be risky.
var restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
Tensor[] restored_tensors = null;
tf_with(ops.device(restore_device), _ =>
{
restored_tensors = gen_ops.restore_v2(file_prefix, tensor_names.ToArray(), slice_specs.ToArray(), tensor_dtypes.ToArray());
});

Dictionary<string, IDictionary<string, Tensor>> restored_tensor_dict = new();
int idx = 0;
Expand Down Expand Up @@ -338,11 +340,14 @@ public Operation save(Tensor file_prefix, CheckpointOptions? options= null)
options = new CheckpointOptions();
}

tf.device("CPU"); // may be risky.
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
Tensor tmp_checkpoint_prefix = null;
tf_with(ops.device("CPU"), _ =>
{
var sharded_suffix = array_ops.where(gen_ops.regex_full_match(file_prefix, tf.constant(@"^s3://.*")),
constant_op.constant(".part"), constant_op.constant("_temp/part"));
var tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
tmp_checkpoint_prefix = gen_ops.string_join(new Tensor[] { file_prefix, sharded_suffix });
IDictionary<string, Tensor> registered_paths = _registered_savers.Keys.ToDictionary(x => x, x => registered_saver_filename(file_prefix, x));
});

Operation save_fn()
{
Expand All @@ -364,16 +369,24 @@ Operation save_fn()
var saver = pair.Value;
last_device = device;
// skip the extra process of device name because of lack of API.
tf.device(device);
var shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
Tensor shard_prefix = null;
tf_with(ops.device(device), _ =>
{
shard_prefix = sharded_filename(tmp_checkpoint_prefix, shard, num_shards_tensor);
});
saved_prefixes.Add(shard_prefix);
sharded_saves.Add(saver.save(shard_prefix, options));
tf_with(ops.device(device), _ =>
{
sharded_saves.Add(saver.save(shard_prefix, options));
});
}
using (var controller = ops.control_dependencies(sharded_saves.ToArray()))
{
string merge_device = string.IsNullOrEmpty(options.experimental_io_device) ? last_device : options.experimental_io_device;
tf.device(merge_device);
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
return tf_with(ops.device(merge_device), _ =>
{
return gen_ops.merge_v2_checkpoints(saved_prefixes.ToArray(), tf.constant(file_prefix), delete_old_dirs: true);
});
}
}

Expand Down Expand Up @@ -407,54 +420,56 @@ IDictionary<string, Operation> restore_func()
{
var device = single_saver.Key;
var saver = single_saver.Value;
tf.device(device);
var restored_tensor_dict = saver.restore(file_prefix, options);

foreach(var pair in restored_tensor_dict)
tf_with(ops.device(device), _ =>
{
var checkpoint_key = pair.Key;
var slice_and_tensor = pair.Value;
foreach(var item in slice_and_tensor)
var restored_tensor_dict = saver.restore(file_prefix, options);

foreach (var pair in restored_tensor_dict)
{
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
var checkpoint_key = pair.Key;
var slice_and_tensor = pair.Value;
foreach (var item in slice_and_tensor)
{
if (!internal_dict.ContainsKey(checkpoint_key))
var slice_spec = item.Key;
var tensor = item.Value;
var restore_fn = _keys_to_restore_fn[(checkpoint_key, slice_spec)];
var internal_dict = restore_fn_inputs.SetDefault(restore_fn, new Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>>());
if (!string.IsNullOrEmpty(slice_spec))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
if (!internal_dict.ContainsKey(checkpoint_key))
{
Dictionary<string, Tensor> dict = new();
dict[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(dict);
}
else
{
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
}
}
else
{
internal_dict[checkpoint_key].GetValue<IDictionary<string, Tensor>>()[slice_spec] = tensor;
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
}
}
else
{
internal_dict[checkpoint_key] = new Maybe<Tensor, IDictionary<string, Tensor>>(tensor);
}
restore_fn_input_count[restore_fn]--;
restore_fn_input_count[restore_fn]--;

if (restore_fn_input_count[restore_fn] == 0)
{
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach(var input in restore_fn_inputs[restore_fn])
if (restore_fn_input_count[restore_fn] == 0)
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
}
var ret = restore_fn.DynamicInvoke(restored_tensors);
if(ret is IDictionary<string, Operation>)
{
var dict = (IDictionary<string, Operation>)ret;
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
Dictionary<string, Maybe<Tensor, IDictionary<string, Tensor>>> restored_tensors = new();
foreach (var input in restore_fn_inputs[restore_fn])
{
restored_tensors[TrackableUtils.extract_local_name(input.Key)] = input.Value;
}
var ret = restore_fn.DynamicInvoke(restored_tensors);
if (ret is IDictionary<string, Operation>)
{
var dict = (IDictionary<string, Operation>)ret;
restore_ops = restore_ops.Concat(dict).ToDictionary(x => x.Key, x => x.Value);
}
}
}
}
}
});
}

foreach(var item in _registered_savers)
Expand Down Expand Up @@ -500,21 +515,25 @@ public SaverDef to_proto()
private Tensor _traced_save(Tensor file_prefix)
{
var save_op = save(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(new object[]{ save_op }))
return tf_with(ops.device("cpu:0"), _ =>
{
return array_ops.identity(file_prefix);
}
return tf_with(ops.control_dependencies(new object[] { save_op }), __ =>
{
return array_ops.identity(file_prefix);
});
});
}

private Tensor _traced_restore(Tensor file_prefix)
{
var restore_op = restore(file_prefix);
tf.device("cpu:0");
using (ops.control_dependencies(restore_op.Values.ToArray()))
return tf_with(ops.device("cpu:0"), _ =>
{
return array_ops.identity(file_prefix);
}
return tf_with(ops.control_dependencies(restore_op.Values.ToArray()), __ =>
{
return array_ops.identity(file_prefix);
});
});
}

public static MultiDeviceSaver from_saveables(IEnumerable<MySaveableObject> saveables, IDictionary<string, IDictionary<string, Trackable>>? registered_savers = null, bool call_with_mapped_captures = false)
Expand Down
73 changes: 73 additions & 0 deletions src/TensorFlowNET.Core/Contexts/Context.Device.cs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ limitations under the License.
using static Tensorflow.Binding;
using Google.Protobuf;
using Tensorflow.Device;
using Tensorflow.Exceptions;
using System.Collections.Generic;

namespace Tensorflow.Contexts
Expand All @@ -30,10 +31,30 @@ namespace Tensorflow.Contexts
/// </summary>
public sealed partial class Context
{
// Cache mapping (raw device name, current device name) -> (canonical name, parsed spec),
// shared across contexts; cleared in reset_context().
internal static Dictionary<(string, string), (string, DeviceSpec)> _device_parsing_cache = new();
// Logical devices discovered from the native context; populated by _initialize_logical_devices().
internal List<LogicalDevice> _logical_devices = null;
// Canonical device names discovered from the native context; null until the context is initialized.
internal List<string> _context_devices = null;

// Placement policy passed to TFE_ContextOptionsSetDevicePlacementPolicy at initialization.
ContextDevicePlacementPolicy _device_policy;
// Whether op device placement is logged (see log_device_placement()).
bool _log_device_placement;
// Number of GPUs counted during _initialize_logical_devices().
int _num_gpus;
// Per-physical-device memory-growth setting.
Dictionary<PhysicalDevice, bool> _memory_growth_map = new Dictionary<PhysicalDevice, bool>();

// Currently active device name ("" means unset); updated via _set_device().
public string DeviceName { get; set; } = "";
// Parsed spec for the currently active device; kept in sync with DeviceName by _set_device().
public DeviceSpec DeviceSpec { get; set; } = null;

/// <summary>
/// Canonical names of the devices known to this context.
/// Only valid after the context has been initialized; throws otherwise.
/// </summary>
internal List<string> Devices
    => _context_devices ?? throw new AssertionError("Context must be initialized first.");

public void log_device_placement(bool enable)
{
if (_handle != null)
Expand Down Expand Up @@ -89,5 +110,57 @@ public PhysicalDevice[] list_physical_devices(string device_type = null)

return results.ToArray();
}

/// <summary>
/// Creates a scoped device-placement context for this eager context;
/// entering it routes subsequent ops to the device identified by <paramref name="name"/>.
/// </summary>
/// <param name="name">Device name, e.g. "/cpu:0".</param>
public EagerDeviceContext device(string name) => new EagerDeviceContext(this, name);

/// <summary>
/// Records the active device on this context, keeping the name and its
/// parsed spec in sync. The two properties are independent, so assignment
/// order is not significant.
/// </summary>
internal void _set_device(string device_name, DeviceSpec device_spec)
{
    this.DeviceName = device_name;
    this.DeviceSpec = device_spec;
}

// Enumerates the devices known to the native TFE context and caches them on this
// instance: canonical names into _context_devices, parsed logical devices into
// _logical_devices, and a GPU count into _num_gpus.
// NOTE(review): the handle returned by TFE_ContextListDevices does not appear to be
// released here — confirm whether a corresponding delete call is needed.
internal void _initialize_logical_devices()
{
List<LogicalDevice> logical_devices = new();
List<string> context_devices = new();
Status status = new();
var device_list = c_api.TFE_ContextListDevices(_handle, status);
status.Check(true);
try
{
this._num_gpus = 0;
// NOTE(review): current_job/current_task are never reassigned in this loop, so the
// GPU-count condition below only matches specs whose job is null and task is -1
// (i.e. after the "localhost" normalization). Verify this matches upstream intent.
string current_job = null;
int current_task = -1;
for(int i = 0; i < c_api.TF_DeviceListCount(device_list); i++)
{
var dev_name = c_api.TF_DeviceListName(device_list, i, status);
status.Check(true);
context_devices.Add(DeviceUtils.canonical_name(dev_name));
var spec = DeviceSpec.from_string(dev_name);
// Strip the implicit local-job fields so local devices get a neutral spec.
if(spec.Job == "localhost")
{
spec = spec.replace(job: null, replica: -1, task: -1);
}
logical_devices.Add(new LogicalDevice(spec.ToString(), spec.DeviceType));
var dev_type_memory = c_api.TF_DeviceListType(device_list, i, status);
var dev_type = c_api.StringPiece(dev_type_memory);
status.Check(true);
// Count only GPUs belonging to the current (local) job/task.
if(dev_type == "GPU" && spec.Job == current_job && spec.Task == current_task)
{
_num_gpus++;
}
}
}
finally
{
// Publish whatever was gathered even if enumeration failed part-way,
// so Devices does not keep throwing after a partial initialization.
_logical_devices = logical_devices;
_context_devices = context_devices;
}
}
}

// Lightweight, value-equal description of a logical device: its canonical
// name (e.g. "/job:worker/.../device:GPU:0") and its device type string.
public record class LogicalDevice(string name, string device_type);
}
5 changes: 4 additions & 1 deletion src/TensorFlowNET.Core/Contexts/Context.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ public sealed partial class Context
public const int EAGER_MODE = 1;

int defaultExecutionMode = EAGER_MODE;
public string DeviceName { get; set; } = "";
public string ScopeName { get; set; } = "";
bool initialized = false;
ContextSwitchStack context_switches;
Expand Down Expand Up @@ -62,6 +61,8 @@ public void ensure_initialized()
if (initialized)
return;

Debug.Assert(_context_devices is null);

Config = MergeConfig();
FunctionCallOptions.Config = Config;
var config_str = Config.ToByteArray();
Expand All @@ -72,6 +73,7 @@ public void ensure_initialized()
c_api.TFE_ContextOptionsSetDevicePlacementPolicy(opts, _device_policy);
_handle = c_api.TFE_NewContext(opts, status);
status.Check(true);
_initialize_logical_devices();
initialized = true;
}

Expand Down Expand Up @@ -174,6 +176,7 @@ public void reset_context()
{
c_api.TFE_ContextClearCaches(_handle);
}
_device_parsing_cache.Clear();
}

public static implicit operator SafeContextHandle(Context ctx)
Expand Down
Loading