This repository was archived by the owner on Jul 7, 2023. It is now read-only.

Commit cb76e97

hbrylkowski authored and afrozenator committed
Documentation for creating own model (#1589)
* Update mscoco.py * docs for adding new model * corrected contributing link
1 parent c394f62 commit cb76e97

1 file changed: +94 -3 lines changed

docs/new_model.md

+94-3
Original file line number | Diff line number | Diff line change
@@ -5,12 +5,103 @@ version](https://badge.fury.io/py/tensor2tensor.svg)](https://badge.fury.io/py/t
55
[![GitHub
66
Issues](https://img.shields.io/github/issues/tensorflow/tensor2tensor.svg)](https://github.com/tensorflow/tensor2tensor/issues)
77
[![Contributions
8-
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](CONTRIBUTING.md)
8+
welcome](https://img.shields.io/badge/contributions-welcome-brightgreen.svg)](../CONTRIBUTING.md)
99
[![Gitter](https://img.shields.io/gitter/room/nwjs/nw.js.svg)](https://gitter.im/tensor2tensor/Lobby)
1010
[![License](https://img.shields.io/badge/License-Apache%202.0-brightgreen.svg)](https://opensource.org/licenses/Apache-2.0)
1111

1212
Here we show how to create your own model in T2T.
1313

14-
## The T2TModel class
14+
## The T2TModel class - abstract base class for models
1515

16-
TODO: complete.
16+
`T2TModel` has three typical usages:
17+
18+
1. Estimator: The method `make_estimator_model_fn` builds a `model_fn` for
19+
the tf.Estimator workflow of training, evaluation, and prediction.
20+
It invokes the method `call`, which performs the core computation,
21+
followed by `estimator_spec_train`, `estimator_spec_eval`, or
22+
`estimator_spec_predict` depending on the tf.Estimator mode.
23+
2. Layer: The method `call` enables `T2TModel` to be used as a callable by
24+
itself. It calls the following methods:
25+
26+
* `bottom`, which transforms features according to `problem_hparams`' input
27+
and target `Modality`s;
28+
* `body`, which takes features and performs the core model computation to
29+
return output and any auxiliary loss terms;
30+
* `top`, which takes features and the body output, and transforms them
31+
according to `problem_hparams`' input and target `Modality`s to return
32+
the final logits;
33+
* `loss`, which takes the logits, forms any missing training loss, and sums
34+
all loss terms.
35+
3. Inference: The method `infer` enables `T2TModel` to make sequence
36+
predictions by itself.
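
The order in which `call` chains `bottom`, `body`, `top`, and `loss` can be illustrated with a toy stand-in. This is only a sketch in plain Python: `ToyModel` and its arithmetic are hypothetical and do not reproduce the actual `T2TModel` implementation or its method signatures.

```python
class ToyModel:
  """Hypothetical stand-in that mirrors the bottom/body/top/loss call order."""

  def bottom(self, features):
    # Stand-in for the modality-specific input transformation.
    return {name: [float(v) for v in vals] for name, vals in features.items()}

  def body(self, features):
    # Core computation; returns an output plus an auxiliary loss term.
    output = [2.0 * v for v in features["inputs"]]
    return output, {"extra": 0.5}

  def top(self, body_output, features):
    # Stand-in for the projection from body output to final logits.
    return [v + 1.0 for v in body_output]

  def loss(self, logits, features):
    # Toy training loss: mean squared error against the targets.
    targets = features["targets"]
    return sum((l - t) ** 2 for l, t in zip(logits, targets)) / len(targets)

  def call(self, features):
    transformed = self.bottom(features)
    output, losses = self.body(transformed)
    logits = self.top(output, transformed)
    # Form the missing training loss, then sum all loss terms.
    losses["training"] = self.loss(logits, transformed)
    return logits, sum(losses.values())


logits, total_loss = ToyModel().call({"inputs": [1, 2], "targets": [3, 5]})
```

Here the auxiliary loss returned by `body` is carried through and added to the training loss computed in `loss`, which is the summing behaviour the list above describes.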
37+
38+
39+
## Creating your own model
40+
41+
1. Create a class that extends `T2TModel`.
42+
In this example it will be a copy of the existing basic fully connected network:
43+
```python
from tensor2tensor.utils import t2t_model


class MyFC(t2t_model.T2TModel):
  pass
```
49+
50+
2. Implement the `body` method:
51+
```python
import tensorflow as tf

from tensor2tensor.layers import common_layers
from tensor2tensor.utils import t2t_model


class MyFC(t2t_model.T2TModel):

  def body(self, features):
    hparams = self.hparams
    x = features["inputs"]
    shape = common_layers.shape_list(x)
    # Flatten the input; in T2T all feature tensors are 4D.
    x = tf.reshape(x, [-1, shape[1] * shape[2] * shape[3]])
    for i in range(hparams.num_hidden_layers):  # Create the hidden layers.
      x = tf.layers.dense(x, hparams.hidden_size, name="layer_%d" % i)
      x = tf.nn.dropout(x, keep_prob=1.0 - hparams.dropout)
      x = tf.nn.relu(x)
    # Restore two singleton dimensions so the output is 4D for T2T.
    return tf.expand_dims(tf.expand_dims(x, axis=1), axis=1)
```
64+
65+
Method signature of `body`:
66+
* Args:
67+
* features: dict of str to Tensor, where each Tensor has shape [batch_size,
68+
..., hidden_size]. It typically contains keys `inputs` and `targets`.
69+
70+
* Returns either `output` alone or a tuple `(output, losses)`:
71+
* output: Tensor of pre-logit activations with shape [batch_size, ...,
72+
hidden_size].
73+
* losses: Either a single loss as a scalar, a list, a Tensor (to be averaged),
74+
or a dictionary of losses. If losses is a dictionary with the key
75+
"training", losses["training"] is considered the final training
76+
loss and output is considered logits; self.top and self.loss will
77+
be skipped.
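
The accepted return shapes of `body` can be made concrete with a small normalisation helper. This is a simplified sketch in plain Python, not the library's code: `normalize_body_output`, `skips_top_and_loss`, and the `"extra"` key are hypothetical names chosen for illustration, and plain floats stand in for Tensors.

```python
def normalize_body_output(body_out):
  """Split a body() return value into (output, losses dict)."""
  if isinstance(body_out, tuple):
    output, losses = body_out
  else:
    # body() returned only the output, so there is no auxiliary loss.
    output, losses = body_out, 0.0
  if isinstance(losses, (list, tuple)):
    # A list of losses is averaged into a single value.
    losses = sum(losses) / len(losses)
  if not isinstance(losses, dict):
    # "extra" is an illustrative key for a non-training loss term.
    losses = {"extra": losses}
  return output, losses


def skips_top_and_loss(losses):
  """True when body() already supplied the final training loss."""
  return "training" in losses


out_only = normalize_body_output("activations")
with_list = normalize_body_output(("activations", [1.0, 3.0]))
_, final = normalize_body_output(("logits", {"training": 1.5}))
```

In the last case the dictionary carries a `"training"` key, so under the rule described above the output would be treated as logits and `top` and `loss` would not run.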
78+
79+
3. Register your model
80+
```python
from tensor2tensor.utils import registry
from tensor2tensor.utils import t2t_model


@registry.register_model
class MyFC(t2t_model.T2TModel):
  ...  # body() as implemented above
```
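
To show what a registration decorator of this kind does, here is a minimal sketch of a name-keyed registry in plain Python. It is an illustration under stated assumptions, not the `tensor2tensor.utils.registry` implementation: `_MODELS`, `to_snake_case`, `register_model`, and `model` are hypothetical names defined only for this example.

```python
import re

_MODELS = {}  # Hypothetical registry: snake_case name -> model class.


def to_snake_case(name):
  """Convert a CamelCase class name to snake_case."""
  s = re.sub(r"(.)([A-Z][a-z0-9]+)", r"\1_\2", name)
  return re.sub(r"([a-z0-9])([A-Z])", r"\1_\2", s).lower()


def register_model(cls):
  """Class decorator that stores cls under its snake_case name."""
  name = to_snake_case(cls.__name__)
  if name in _MODELS:
    raise LookupError("Model %s already registered." % name)
  _MODELS[name] = cls
  return cls  # Return the class unchanged so it can still be used directly.


def model(name):
  """Look up a registered class by its snake_case name."""
  return _MODELS[name]


@register_model
class MyFC:
  pass


registered = model("my_fc")
```

Because the decorator returns the class unchanged, registration has no effect on how the class behaves; it only makes the class discoverable by a string name such as the one passed on a command line.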
87+
88+
4. Use it with the t2t tools like any other model
89+
90+
Keep in mind that model names are translated from CamelCase to snake_case (`MyFC` -> `my_fc`)
91+
and that you need to point t2t to the directory containing your model with the `--t2t_usr_dir` flag.
92+
For example, to train the model on gcloud with 1 GPU worker on the IMDB sentiment task, run
93+
the following command from the directory containing your model class.
94+
95+
```bash
t2t-trainer \
  --model=my_fc \
  --t2t_usr_dir=. \
  --cloud_mlengine --worker_gpu=1 \
  --generate_data \
  --data_dir='gs://data' \
  --output_dir='gs://out' \
  --problem=sentiment_imdb \
  --hparams_set=basic_fc_small \
  --train_steps=10000 \
  --eval_steps=10
```
