
Commit

Some experiments with TF, 3
osmr committed Oct 3, 2018
1 parent b161309 commit 5efbb9d
Showing 4 changed files with 343 additions and 224 deletions.
.travis.yml (2 changes: 1 addition & 1 deletion)
@@ -19,7 +19,7 @@ before_script:
 # stop the build if there are Python syntax errors or undefined names
 - flake8 . --count --select=E901,E999,F821,F822,F823 --show-source --statistics
 # exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
-- flake8 . --count --exit-zero --max-complexity=32 --max-line-length=127 --ignore=F403,F405,E126,E127 --exclude=./pytorch/pytorchcv/models/others,./tensorflow_ --statistics
+- flake8 . --count --exit-zero --max-complexity=32 --max-line-length=127 --ignore=F403,F405,E126,E127 --exclude=./pytorch/pytorchcv/models/others,./tensorflow_/others --statistics
 script:
 - true # pytest --capture=sys # add others tests here
 notifications:
eval_tf.py (66 changes: 32 additions & 34 deletions)
@@ -1,12 +1,14 @@
 import argparse
+import tqdm
 import time
 import logging

-import mxnet as mx
+from tensorpack.predict import PredictConfig, FeedfreePredictor
+from tensorpack.utils.stats import RatioCounter
+from tensorpack.input_source import QueueInput, StagingInput

 from common.logger_utils import initialize_logging
-from tensorflow_.utils import prepare_tf_context, prepare_model, get_data_loader, calc_net_weight_count,\
-    validate
+from tensorflow_.utils import prepare_tf_context, prepare_model, get_data


 def parse_args():
@@ -78,29 +80,30 @@ def parse_args():


 def test(net,
-         val_data,
-         batch_fn,
-         use_rec,
-         dtype,
-         ctx,
-         calc_weight_count=False,
+         session_init,
+         val_dataflow,
          extended_log=False):
-    acc_top1 = mx.metric.Accuracy()
-    acc_top5 = mx.metric.TopKAccuracy(5)
+
+    pred_config = PredictConfig(
+        model=net,
+        session_init=session_init,
+        input_names=['input', 'label'],
+        output_names=['wrong-top1', 'wrong-top5']
+    )
+    err_top1 = RatioCounter()
+    err_top5 = RatioCounter()

     tic = time.time()
-    err_top1_val, err_top5_val = validate(
-        acc_top1=acc_top1,
-        acc_top5=acc_top5,
-        net=net,
-        val_data=val_data,
-        batch_fn=batch_fn,
-        use_rec=use_rec,
-        dtype=dtype,
-        ctx=ctx)
-    if calc_weight_count:
-        weight_count = calc_net_weight_count(net)
-        logging.info('Model: {} trainable parameters'.format(weight_count))
+    pred = FeedfreePredictor(pred_config, StagingInput(QueueInput(val_dataflow), device='/gpu:0'))
+    for _ in tqdm.trange(val_dataflow.size()):
+        err_top1_val, err_top5_val = pred()
+        batch_size = err_top1_val.shape[0]
+        err_top1.feed(err_top1_val.sum(), batch_size)
+        err_top5.feed(err_top5_val.sum(), batch_size)
+
+    err_top1_val = err_top1.ratio
+    err_top5_val = err_top5.ratio
+
     if extended_log:
         logging.info('Test: err-top1={top1:.4f} ({top1})\terr-top5={top5:.4f} ({top5})'.format(
             top1=err_top1_val, top5=err_top5_val))
@@ -125,25 +128,20 @@ def main():
         num_gpus=args.num_gpus,
         batch_size=args.batch_size)

-    net = prepare_model(
+    net, inputs_desc = prepare_model(
         model_name=args.model,
         pretrained_model_file_path=args.resume.strip())

-    train_data, val_data, batch_fn = get_data_loader(
-        data_dir=args.data_dir,
+    val_dataflow = get_data(
+        is_train=False,
         batch_size=batch_size,
-        num_workers=args.num_workers)
+        data_dir_path=args.data_dir)

     assert (args.use_pretrained or args.resume.strip())
     test(
         net=net,
-        val_data=val_data,
-        batch_fn=batch_fn,
-        use_rec=args.use_rec,
-        dtype=args.dtype,
-        # ctx=ctx,
-        # calc_weight_count=(not log_file_exist),
-        calc_weight_count=True,
+        session_init=inputs_desc,
+        val_dataflow=val_dataflow,
         extended_log=True)


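For context, the rewritten test() measures accuracy by feeding the per-batch 'wrong-top1' / 'wrong-top5' indicator vectors into tensorpack's RatioCounter and reading back the aggregate ratio. Below is a minimal sketch of that aggregation step on its own, using made-up NumPy batches instead of a real FeedfreePredictor, so the numbers are illustrative only:

import numpy as np
from tensorpack.utils.stats import RatioCounter

err_top1 = RatioCounter()

# Each element is 1.0 if the top-1 prediction for that sample was wrong, else 0.0,
# mirroring the 'wrong-top1' output consumed in eval_tf.py.
fake_batches = [
    np.array([0.0, 1.0, 0.0, 0.0]),  # 1 error in a batch of 4
    np.array([1.0, 1.0, 0.0, 0.0]),  # 2 errors in a batch of 4
]

for wrong_top1 in fake_batches:
    batch_size = wrong_top1.shape[0]
    err_top1.feed(wrong_top1.sum(), batch_size)  # accumulate (error count, sample count)

print('top-1 error: {:.4f}'.format(err_top1.ratio))  # 3 / 8 = 0.3750

The same pattern, with batches produced by FeedfreePredictor rather than NumPy arrays, is what the tqdm loop in the new test() performs for both the top-1 and top-5 counters.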
