From aa159dbad47e17416fab3fd053aae09dae1b1a36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A1ndor=20Bal=C3=A1zs?= Date: Sat, 14 Aug 2021 09:27:13 +0200 Subject: [PATCH 01/16] add redo matching to forntend, add wandb to collect stats, refactor some parts --- .gitignore | 2 + policy/frontend/src/components/Direction.tsx | 6 +- policy/frontend/src/pages/ModelsPage.tsx | 4 +- policy/frontend/src/pages/SessionPage.tsx | 31 +++- policy/frontend/src/pages/TrainPage.tsx | 25 ++- policy/frontend/src/utils/useProgress.ts | 10 +- policy/openbot/associate_frames.py | 160 +++++++++++++------ policy/openbot/dataloader.py | 65 +++++--- policy/openbot/models.py | 2 +- policy/openbot/server/api.py | 6 +- policy/openbot/server/dataset.py | 14 +- policy/openbot/train.py | 86 ++++++---- policy/policy_learning.ipynb | 2 +- 13 files changed, 285 insertions(+), 128 deletions(-) diff --git a/.gitignore b/.gitignore index 62bb9b0b7..cb9f3a95d 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ composer.lock /phpunit.xml .DS_Store Thumbs.db +esphome/ +wandb/ # IDE's .idea/ diff --git a/policy/frontend/src/components/Direction.tsx b/policy/frontend/src/components/Direction.tsx index 9b05c2817..463ebe142 100644 --- a/policy/frontend/src/components/Direction.tsx +++ b/policy/frontend/src/components/Direction.tsx @@ -25,9 +25,9 @@ export function Direction({left, right}: { left: number; right: number }) {
{name}
- {distanceInWordsToNow(new Date(mtime * 1000).toISOString())} + {distanceInWordsToNow(new Date(mtime * 1000))}
@@ -79,7 +79,7 @@ function ModelDetails() { - + diff --git a/policy/frontend/src/pages/SessionPage.tsx b/policy/frontend/src/pages/SessionPage.tsx index 63927f2a2..b34083b43 100644 --- a/policy/frontend/src/pages/SessionPage.tsx +++ b/policy/frontend/src/pages/SessionPage.tsx @@ -1,8 +1,19 @@ import {styled} from 'goober'; -import {useEffect, useRef, useState} from 'react'; +import {useCallback, useEffect, useRef, useState} from 'react'; import {useHotkeys} from 'react-hotkeys-hook'; import {useRouteMatch} from 'react-router-dom'; -import {FlexboxGrid, Icon, IconButton, InputNumber, Loader, Message, Panel, SelectPicker, Slider} from 'rsuite'; +import { + Dropdown, + FlexboxGrid, + Icon, + IconButton, + InputNumber, + Loader, + Message, + Panel, + SelectPicker, + Slider, +} from 'rsuite'; import {Direction} from 'src/components/Direction'; import {formatTime} from 'src/utils/formatTime'; import {Session} from 'src/utils/useDatasets'; @@ -25,7 +36,7 @@ export function SessionPage() { return <>

{match.params.path}

- {session ? : } + {session ? : } ; @@ -48,7 +59,7 @@ const PreviewCont = styled('div')` } `; -function SessionComp({session}: {session: Session}) { +function SessionComp({session, reload}: {session: Session, reload: () => any}) { const max = session.ctrl.length - 1; const models = useModels(); const [playing, togglePlaying] = useToggle(false); @@ -63,6 +74,10 @@ function SessionComp({session}: {session: Session}) { useHotkeys('space', togglePlaying); useHotkeys('right', () => setCurrent(c => c === max ? 0 : c + 1), [max]); useHotkeys('left', () => setCurrent(c => c === 0 ? max : c - 1), [max]); + const redoMatching = useCallback(async () => { + await jsonRpc('redoMatching', session.path); + reload(); + }, [reload, session.path]); const prediction = usePrediction(session.path, model, indicator, current); const [predLeft, predRight] = prediction.value[current] || []; @@ -92,6 +107,10 @@ function SessionComp({session}: {session: Session}) { onChange={setIndicator} searchable={false} /> +
+ + Redo matching + @@ -99,10 +118,12 @@ function SessionComp({session}: {session: Session}) { + {left} {right}
Indicator: {ind}
+ {predLeft} {predRight}
{prediction.pending > 0 && }
@@ -146,7 +167,7 @@ function usePrediction(path: string, model: string | null, indicator: string | n batch++; } useEffect(() => { - if (valid.current[batch]) { + if (!model || valid.current[batch]) { return; } valid.current[batch] = true; diff --git a/policy/frontend/src/pages/TrainPage.tsx b/policy/frontend/src/pages/TrainPage.tsx index 016a60279..002275f8c 100644 --- a/policy/frontend/src/pages/TrainPage.tsx +++ b/policy/frontend/src/pages/TrainPage.tsx @@ -1,3 +1,4 @@ +import {differenceInMinutes} from 'date-fns'; import NoSleep from 'nosleep.js'; import {useEffect} from 'react'; import {Button, Panel, Progress} from 'rsuite'; @@ -36,8 +37,11 @@ function TrainProgress({state, clear}: { state: ProgressState, clear: () => any noSleep.disable(); } }, [active]); + const now = new Date(); + const end = predictEndDate(state, now); return <> +

Current step: {state.message}

Current epoch: @@ -46,7 +50,9 @@ function TrainProgress({state, clear}: { state: ProgressState, clear: () => any Training:
-
{state.message}
+
Full time: {differenceInMinutes(end, state.startTime)} minutes
+
Elapsed time: {differenceInMinutes(now, state.startTime)} minutes
+
Remaining time: {differenceInMinutes(end, now)} minutes
{active ? ( @@ -56,7 +62,7 @@ function TrainProgress({state, clear}: { state: ProgressState, clear: () => any
- + @@ -71,10 +77,23 @@ function TrainProgress({state, clear}: { state: ProgressState, clear: () => any preview thumbnails )} - {state.model && ( + {!!state.model && ( + + preview thumbnails + + )} + {state.status === 'success' && ( preview thumbnails )} } + +function predictEndDate(state: ProgressState, now: Date) { + const start = state.startTime.getTime(); + const elapsed = now.getTime() - start; + const fullTime = elapsed / state.percent * 100; + + return new Date(start + fullTime); +} diff --git a/policy/frontend/src/utils/useProgress.ts b/policy/frontend/src/utils/useProgress.ts index 16f4539ae..d305e8738 100644 --- a/policy/frontend/src/utils/useProgress.ts +++ b/policy/frontend/src/utils/useProgress.ts @@ -14,6 +14,7 @@ export interface Hyperparametes { } export interface ProgressState { + startTime: Date; status: 'success' | 'fail' | 'active' | undefined; epoch: number; percent: number; @@ -26,6 +27,7 @@ export interface ProgressState { } const defaultState: ProgressState = { + startTime: new Date(), status: undefined, epoch: 0, percent: 0, @@ -51,6 +53,7 @@ function progressReducer(state: ProgressState, msg: any): ProgressState { switch (msg.event) { case 'started': return { + startTime: new Date(), status: 'active', epoch: 0, percent: 0, @@ -62,6 +65,11 @@ function progressReducer(state: ProgressState, msg: any): ProgressState { ...state, rnd: Date.now(), }; + case 'model': + return { + ...state, + model: msg.payload, + }; case 'logs': return { ...state, @@ -98,11 +106,11 @@ function progressReducer(state: ProgressState, msg: any): ProgressState { ...state, rnd: Date.now(), status: 'success', - model: msg.payload.model, message: 'Done', }; case 'clear': return { + startTime: new Date(), status: undefined, epoch: 0, percent: 0, diff --git a/policy/openbot/associate_frames.py b/policy/openbot/associate_frames.py index 031b62f20..480050e2b 100644 --- a/policy/openbot/associate_frames.py +++ b/policy/openbot/associate_frames.py @@ -38,13 +38,11 @@ Modified and extended by Matthias Mueller - Intel Intelligent Systems Lab - 2020 The controls are event-based and not synchronized to the frames. This script matches the control signals to frames. -Specifically, if there was no control signal event within some threshold (default: 1ms), the last control signal before the frame is used. +Specifically, if there was no control signal event within some threshold (default: 1ms), +the last control signal before the frame is used. """ -import argparse -import sys import os -import numpy from . import utils @@ -65,13 +63,19 @@ def read_file_list(filename): """ f = open(filename) - header = f.readline() #discard header + # discard header + header = f.readline() data = f.read() - lines = data.replace(","," ").replace("\t"," ").split("\n") - data = [[v.strip() for v in line.split(" ") if v.strip()!=""] for line in lines if len(line)>0 and line[0]!="#"] - data = [(int(l[0]),l[1:]) for l in data if len(l)>1] + lines = data.replace(",", " ").replace("\t", " ").split("\n") + data = [ + [v.strip() for v in line.split(" ") if v.strip() != ""] + for line in lines + if len(line) > 0 and line[0] != "#" + ] + data = [(int(line[0]), line[1:]) for line in data if len(line) > 1] return dict(data) + def associate(first_list, second_list, max_offset): """ Associate two dictionaries of (stamp,data). As the time stamps never match exactly, we aim @@ -89,89 +93,143 @@ def associate(first_list, second_list, max_offset): """ first_keys = list(first_list) second_keys = list(second_list) - potential_matches = [(b-a, a, b) - for a in first_keys - for b in second_keys - if (b-a) < max_offset] #Control before image or within max_offset - potential_matches.sort(reverse = True) + potential_matches = [ + (b - a, a, b) for a in first_keys for b in second_keys if (b - a) < max_offset + ] # Control before image or within max_offset + potential_matches.sort(reverse=True) matches = [] for diff, a, b in potential_matches: if a in first_keys and b in second_keys: - first_keys.remove(a) #Remove frame that was assigned - matches.append((a, b)) #Append tuple + first_keys.remove(a) # Remove frame that was assigned + matches.append((a, b)) # Append tuple matches.sort() return matches -def match_frame_ctrl_cmd(data_dir, datasets, max_offset, redo_matching=False, remove_zeros=True): +def match_frame_ctrl_cmd( + data_dir, datasets, max_offset, redo_matching=False, remove_zeros=True +): frames = [] for dataset in datasets: for folder in utils.list_dirs(os.path.join(data_dir, dataset)): session_dir = os.path.join(data_dir, dataset, folder) - frame_list = match_frame_session(session_dir, max_offset, redo_matching, remove_zeros) + frame_list = match_frame_session( + session_dir, max_offset, redo_matching, remove_zeros + ) for timestamp in list(frame_list): frames.append(frame_list[timestamp][0]) return frames -def match_frame_session(session_dir, max_offset, redo_matching=False, remove_zeros=True): + +def match_frame_session( + session_dir, max_offset, redo_matching=False, remove_zeros=True +): sensor_path = os.path.join(session_dir, "sensor_data") img_path = os.path.join(session_dir, "images") - print("Processing folder %s" %(session_dir)) - if (not redo_matching and os.path.isfile(os.path.join(sensor_path,"matched_frame_ctrl.txt"))): + print("Processing folder %s" % (session_dir)) + if not redo_matching and os.path.isfile( + os.path.join(sensor_path, "matched_frame_ctrl.txt") + ): print(" Frames and controls already matched.") else: - #Match frames with control signals - frame_list = read_file_list(os.path.join(sensor_path,"rgbFrames.txt")) + # Match frames with control signals + frame_list = read_file_list(os.path.join(sensor_path, "rgbFrames.txt")) if len(frame_list) == 0: raise Exception("Empty rgbFrames.txt") - ctrl_list = read_file_list(os.path.join(sensor_path,"ctrlLog.txt")) + ctrl_list = read_file_list(os.path.join(sensor_path, "ctrlLog.txt")) if len(ctrl_list) == 0: raise Exception("Empty ctrlLog.txt") matches = associate(frame_list, ctrl_list, max_offset) - with open(os.path.join(sensor_path,"matched_frame_ctrl.txt"), 'w') as f: + with open(os.path.join(sensor_path, "matched_frame_ctrl.txt"), "w") as f: f.write("timestamp (frame),time_offset (ctrl-frame),frame,left,right\n") - for a,b in matches: - f.write("%d %d %s %s \n"%(a,b-a," ".join(frame_list[a]), " ".join(ctrl_list[b]))) + for a, b in matches: + ctrl = ctrl_list[b] + f.write( + "%d,%d,%s,%s\n" + % ( + a, + b - a, + ",".join(frame_list[a]), + ",".join(ctrl), + ) + ) print(" Frames and controls matched.") - if (not redo_matching and os.path.isfile(os.path.join(sensor_path,"matched_frame_ctrl_cmd.txt"))): + if not redo_matching and os.path.isfile( + os.path.join(sensor_path, "matched_frame_ctrl_cmd.txt") + ): print(" Frames and commands already matched.") else: - #Match frames and controls with indicator commands - frame_list = read_file_list(os.path.join(sensor_path,"matched_frame_ctrl.txt")) + # Match frames and controls with indicator commands + frame_list = read_file_list(os.path.join(sensor_path, "matched_frame_ctrl.txt")) if len(frame_list) == 0: raise Exception("Empty matched_frame_ctrl.txt") - cmd_list = read_file_list(os.path.join(sensor_path,"indicatorLog.txt")) - #Set indicator signal to 0 for initial frames - if len(cmd_list) == 0 or sorted(frame_list)[0]0 and line[0]!="#"] - #Tuples containing id: framepath and label: left,right,cmd - data = [(l[1],l[2:]) for l in data if len(l)>1] + lines = ( + data.replace(",", " ") + .replace("\\", "/") + .replace("\r", "") + .replace("\t", " ") + .split("\n") + ) + data = [ + [v.strip() for v in line.split(" ") if v.strip() != ""] + for line in lines + if len(line) > 0 and line[0] != "#" + ] + # Tuples containing id: framepath and label: left,right,cmd + data = [(line[1], line[2:]) for line in data if len(line) > 1] corpus.extend(data) return dict(corpus) # build a lookup table to get the frame index for the label - def lookup_table (self): + def lookup_table(self): table = tf.lookup.StaticHashTable( initializer=tf.lookup.KeyValueTensorInitializer( keys=list(self.labels.keys()), values=list(i for i in range(len(self.labels.keys()))), ), default_value=tf.constant(-1), - name="frame_index" + name="frame_index", ) - return table + return table def get_label(self, file_path): index = self.index_table.lookup(file_path) - return self.cmd_values[index], self.label_values[index]/255 \ No newline at end of file + return self.cmd_values[index], self.label_values[index] / 255 diff --git a/policy/openbot/models.py b/policy/openbot/models.py index 48f13dfb1..3650cc55d 100644 --- a/policy/openbot/models.py +++ b/policy/openbot/models.py @@ -72,7 +72,7 @@ def pilot_net(img_width, img_height, bn=False): bn=bn) # fuse input MLP and CNN - combinedInput = tf.keras.layers.concatenate([mlp.input, cnn.output]) + combinedInput = tf.keras.layers.concatenate([mlp.output, cnn.output]) # output MLP x = tf.keras.layers.Dense(50, activation="relu")(combinedInput) diff --git a/policy/openbot/server/api.py b/policy/openbot/server/api.py index 33cb58403..2154eeddd 100644 --- a/policy/openbot/server/api.py +++ b/policy/openbot/server/api.py @@ -1,5 +1,4 @@ import asyncio -import glob import os import shutil import threading @@ -9,12 +8,12 @@ import numpy as np from numpyencoder import NumpyEncoder -from .dataset import get_dataset_list, get_dir_info, get_info +from .dataset import get_dataset_list, get_dir_info, get_info, redoMatching from .models import get_model_info, get_models, getModelFiles, publishModel, deleteModelFile from .preview import handle_preview from .prediction import getPrediction from .upload import handle_file_upload -from .. import base_dir, dataset_dir, models_dir +from .. import base_dir, dataset_dir from ..train import CancelledException, Hyperparameters, MyCallback, start_train, create_tfrecord event_cancelled = threading.Event() @@ -82,6 +81,7 @@ async def init_api(app: web.Application): ("", getSession), ("", moveSession), ("", deleteSession), + ("", redoMatching), ("", start), ("", stop), ) diff --git a/policy/openbot/server/dataset.py b/policy/openbot/server/dataset.py index f41719d8e..6128e3789 100644 --- a/policy/openbot/server/dataset.py +++ b/policy/openbot/server/dataset.py @@ -45,8 +45,8 @@ def get_info(path, basename=None): if not os.path.isdir(real_path): return None - isSession = is_session(real_path) - if isSession: + is_session = os.path.isdir(real_path + "/images") + if is_session: try: max_offset = 1e3 frames = associate_frames.match_frame_session( @@ -72,7 +72,7 @@ def get_info(path, basename=None): return { "path": "/" + path, "name": basename, - "is_session": isSession, + "is_session": is_session, "ctrl": ctrl, "seconds": seconds, "error": error, @@ -86,14 +86,16 @@ def get_info(path, basename=None): return { "path": "/" + path, "name": basename, - "is_session": isSession, + "is_session": is_session, "files": file_count - dir_count, "dirs": dir_count, } -def is_session(path): - return os.path.isdir(path + "/images") +def redoMatching(path): + max_offset = 1e3 + associate_frames.match_frame_session(dataset_dir + path, max_offset, True, True) + return True def count_lines(path): diff --git a/policy/openbot/train.py b/policy/openbot/train.py index 4cad0beaa..59f152188 100644 --- a/policy/openbot/train.py +++ b/policy/openbot/train.py @@ -6,6 +6,8 @@ import matplotlib.pyplot as plt import numpy as np import tensorflow as tf +import wandb +from wandb.keras import WandbCallback from . import ( associate_frames, @@ -87,6 +89,8 @@ def __init__(self, params: Hyperparameters): self.test_data_dir = "" self.train_datasets = [] self.test_datasets = [] + self.redo_matching = False + self.remove_zeros = True self.image_count_train = 0 self.image_count_test = 0 self.train_ds = None @@ -136,13 +140,13 @@ def on_batch_end(self, batch, logs=None): self.broadcast( "progress", dict( - epoch=int(100 * self.step / steps), - train=int(100 * (self.epoch * steps + self.step) / (epochs * steps)), + epoch=round(100 * self.step / steps, 1), + train=round(100 * (self.epoch * steps + self.step) / (epochs * steps), 1), ), ) -def process_data(tr: Training, redo_matching=False, remove_zeros=True): +def process_data(tr: Training): tr.train_datasets = utils.list_dirs(tr.train_data_dir) tr.test_datasets = utils.list_dirs(tr.test_data_dir) @@ -155,15 +159,15 @@ def process_data(tr: Training, redo_matching=False, remove_zeros=True): tr.train_data_dir, tr.train_datasets, max_offset, - redo_matching=redo_matching, - remove_zeros=remove_zeros, + redo_matching=tr.redo_matching, + remove_zeros=tf.remove_zeros, ) test_frames = associate_frames.match_frame_ctrl_cmd( tr.test_data_dir, tr.test_datasets, max_offset, - redo_matching=redo_matching, - remove_zeros=remove_zeros, + redo_matching=tr.redo_matching, + remove_zeros=tf.remove_zeros, ) tr.image_count_train = len(train_frames) @@ -194,7 +198,7 @@ def process_test_sample(features): label = [features["left"], features["right"]] return (image, cmd), label - train_dataset = ( + train_dataset = ( tf.data.TFRecordDataset(tr.train_data_dir, num_parallel_reads=AUTOTUNE) .map(tfrecord_utils.parse_tfrecord_fn, num_parallel_calls=AUTOTUNE) .map(process_train_sample, num_parallel_calls=AUTOTUNE) @@ -231,7 +235,7 @@ def process_test_sample(features): .prefetch(AUTOTUNE) ) - tr.test_ds = ( + tr.test_ds = ( test_dataset.batch(tr.hyperparameters.TEST_BATCH_SIZE) .prefetch(AUTOTUNE) ) @@ -306,14 +310,22 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): tr.model_name = dataset_name + "_" + str(tr.hyperparameters) tr.checkpoint_path = os.path.join(models_dir, tr.model_name, "checkpoints") tr.custom_objects = {'direction_metric':metrics.direction_metric, 'angle_metric':metrics.angle_metric} + model_path = os.path.join(models_dir, tr.model_name, "model") + + wandb.init(project="openbot") + + config = wandb.config + config.epochs = tr.hyperparameters.NUM_EPOCHS + config.learning_rate = tr.hyperparameters.LEARNING_RATE + config.batch_size = tr.hyperparameters.TRAIN_BATCH_SIZE + config["model_name"] = tr.model_name + append_logs = False model: tf.keras.Model if tr.hyperparameters.USE_LAST: append_logs = True - dirs = utils.list_dirs(tr.checkpoint_path) - last_checkpoint = sorted(dirs)[-1] model = tf.keras.models.load_model( - os.path.join(tr.checkpoint_path, last_checkpoint), + model_path, custom_objects=tr.custom_objects, compile=False, ) @@ -323,6 +335,10 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): tr.NETWORK_IMG_HEIGHT, tr.hyperparameters.BATCH_NORM, ) + dot_img_file = os.path.join(models_dir, tr.model_name, "model.png") + tf.keras.utils.plot_model(model, to_file=dot_img_file, show_shapes=True) + + callback.broadcast("model", tr.model_name) tr.loss_fn = losses.sq_weighted_mse_angle tr.metric_list = ["mean_absolute_error", tr.custom_objects['direction_metric'], tr.custom_objects['angle_metric']] @@ -350,49 +366,51 @@ def do_training(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): callbacks.checkpoint_cb(tr.checkpoint_path), callbacks.tensorboard_cb(tr.log_path), callbacks.logger_cb(tr.log_path, append_logs), + WandbCallback(), callback, ], ) + model.save(model_path) + wandb.save(model_path) + wandb.finish() def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0): callback.broadcast("message", "Generate plots...") - history = tr.history - log_path = tr.log_path - plt.plot(history.history["mean_absolute_error"], label="mean_absolute_error") - plt.plot(history.history["val_mean_absolute_error"], label="val_mean_absolute_error") + plt.plot(tr.history.history["mean_absolute_error"], label="mean_absolute_error") + plt.plot(tr.history.history["val_mean_absolute_error"], label="val_mean_absolute_error") plt.xlabel("Epoch") plt.ylabel("Mean Absolute Error") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "error.png")) + savefig(os.path.join(tr.log_path, "error.png")) - plt.plot(history.history["direction_metric"], label="direction_metric") - plt.plot(history.history["val_direction_metric"], label="val_direction_metric") + plt.plot(tr.history.history["direction_metric"], label="direction_metric") + plt.plot(tr.history.history["val_direction_metric"], label="val_direction_metric") plt.xlabel("Epoch") plt.ylabel("Direction Metric") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "direction.png")) + savefig(os.path.join(tr.log_path, "direction.png")) - plt.plot(history.history["angle_metric"], label="angle_metric") - plt.plot(history.history["val_angle_metric"], label="val_angle_metric") + plt.plot(tr.history.history["angle_metric"], label="angle_metric") + plt.plot(tr.history.history["val_angle_metric"], label="val_angle_metric") plt.xlabel("Epoch") plt.ylabel("Angle Metric") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "angle.png")) + savefig(os.path.join(tr.log_path, "angle.png")) - plt.plot(history.history["loss"], label="loss") - plt.plot(history.history["val_loss"], label="val_loss") + plt.plot(tr.history.history["loss"], label="loss") + plt.plot(tr.history.history["val_loss"], label="val_loss") plt.xlabel("Epoch") plt.ylabel("Loss") plt.legend(loc="lower right") - savefig(os.path.join(log_path, "loss.png")) + savefig(os.path.join(tr.log_path, "loss.png")) callback.broadcast("message", "Generate tflite models...") checkpoint_path = tr.checkpoint_path print("checkpoint_path", checkpoint_path) best_index = np.argmax( - np.array(history.history["val_angle_metric"]) - + np.array(history.history["val_direction_metric"]) + np.array(tr.history.history["val_angle_metric"]) + + np.array(tr.history.history["val_direction_metric"]) ) best_checkpoint = str("cp-%04d.ckpt" % (best_index + 1)) best_tflite = utils.generate_tflite(checkpoint_path, best_checkpoint) @@ -400,8 +418,8 @@ def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0 print( "Best Checkpoint (val_angle: %s, val_direction: %s): %s" % ( - history.history["val_angle_metric"][best_index], - history.history["val_direction_metric"][best_index], + tr.history.history["val_angle_metric"][best_index], + tr.history.history["val_direction_metric"][best_index], best_checkpoint, ) ) @@ -412,8 +430,8 @@ def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0 print( "Last Checkpoint (val_angle: %s, val_direction: %s): %s" % ( - history.history["val_angle_metric"][-1], - history.history["val_direction_metric"][-1], + tr.history.history["val_angle_metric"][-1], + tr.history.history["val_direction_metric"][-1], last_checkpoint, ) ) @@ -441,7 +459,7 @@ def do_evaluation(tr: Training, callback: tf.keras.callbacks.Callback, verbose=0 utils.show_test_batch( image_batch.numpy(), cmd_batch.numpy(), label_batch.numpy(), pred_batch ) - savefig(os.path.join(log_path, "test_preview.png")) + savefig(os.path.join(tr.log_path, "test_preview.png")) utils.compare_tf_tflite(best_model, best_tflite) @@ -506,7 +524,7 @@ def create_tfrecord(callback: MyCallback): def broadcast(event, payload=None): print() print(event, payload) - + event = threading.Event() my_callback = MyCallback(broadcast, event) diff --git a/policy/policy_learning.ipynb b/policy/policy_learning.ipynb index 42406abe8..a064169d9 100644 --- a/policy/policy_learning.ipynb +++ b/policy/policy_learning.ipynb @@ -146,7 +146,7 @@ "metadata": {}, "outputs": [], "source": [ - "train.process_data(tr, redo_matching=False, remove_zeros=True)" + "train.process_data(tr)" ] }, { From 79e12a63474287fbdccc89a23992741badfb1435 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A1ndor=20Bal=C3=A1zs?= Date: Sat, 14 Aug 2021 22:53:37 +0200 Subject: [PATCH 02/16] android: add server selector to Autopilot and Logger, remove some duplication --- .../openbot/autopilot/AutopilotFragment.java | 93 +---------- .../org/openbot/common/ControlsFragment.java | 152 +++++++++++++++++- .../openbot/env/SharedPreferencesManager.java | 9 ++ .../org/openbot/logging/LoggerFragment.java | 37 +---- .../openbot/objectNav/ObjectNavFragment.java | 40 +---- .../org/openbot/original/CameraActivity.java | 4 + .../openbot/server/ServerCommunication.java | 42 ++++- .../org/openbot/server/ServerListener.java | 4 + .../res/layout-land/fragment_autopilot.xml | 30 ++-- .../main/res/layout-land/fragment_logger.xml | 115 +++++++------ .../main/res/layout/fragment_autopilot.xml | 29 ++-- .../src/main/res/layout/fragment_logger.xml | 75 +++++---- android/app/src/main/res/values/strings.xml | 6 + 13 files changed, 365 insertions(+), 271 deletions(-) diff --git a/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java b/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java index ec08abf33..73c0d0f58 100644 --- a/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java +++ b/android/app/src/main/java/org/openbot/autopilot/AutopilotFragment.java @@ -14,19 +14,16 @@ import android.view.View; import android.view.ViewGroup; import android.widget.AdapterView; -import android.widget.ArrayAdapter; import android.widget.Toast; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.camera.core.ImageProxy; import androidx.navigation.Navigation; import com.google.android.material.bottomsheet.BottomSheetBehavior; -import java.io.File; import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.jetbrains.annotations.NotNull; import org.openbot.R; import org.openbot.common.CameraFragment; @@ -34,25 +31,21 @@ import org.openbot.env.BorderedText; import org.openbot.env.Control; import org.openbot.env.ImageUtils; -import org.openbot.server.ServerCommunication; -import org.openbot.server.ServerListener; import org.openbot.tflite.Autopilot; import org.openbot.tflite.Model; import org.openbot.tflite.Network; import org.openbot.tracking.MultiBoxTracker; import org.openbot.utils.Constants; import org.openbot.utils.Enums; -import org.openbot.utils.FileUtils; import org.openbot.utils.PermissionUtils; import timber.log.Timber; -public class AutopilotFragment extends CameraFragment implements ServerListener { +public class AutopilotFragment extends CameraFragment { // options for drop down in object nav? private FragmentAutopilotBinding binding; private Handler handler; private HandlerThread handlerThread; - private ServerCommunication serverCommunication; private long lastProcessingTimeMs; private boolean computingNetwork = false; @@ -71,8 +64,6 @@ public class AutopilotFragment extends CameraFragment implements ServerListener private Network.Device device = Network.Device.CPU; private int numThreads = -1; - private ArrayAdapter modelAdapter; - @Override public void onCreate(@Nullable Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -97,42 +88,11 @@ public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceStat binding.threads.setText(String.valueOf(getNumThreads())); binding.cameraToggle.setOnClickListener(v -> toggleCamera()); - List models = - masterList.stream() - .filter(f -> f.type.equals(Model.TYPE.AUTOPILOT) && f.pathType != Model.PATH_TYPE.URL) - .map(f -> FileUtils.nameWithoutExtension(f.name)) - .collect(Collectors.toList()); - modelAdapter = new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); - - modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); - binding.modelSpinner.setAdapter(modelAdapter); - if (!preferencesManager.getAutopilotModel().isEmpty()) - binding.modelSpinner.setSelection( - Math.max( - 0, - modelAdapter.getPosition( - FileUtils.nameWithoutExtension(preferencesManager.getAutopilotModel())))); + List models = getModelNames(f -> f.type.equals(Model.TYPE.AUTOPILOT) && f.pathType != Model.PATH_TYPE.URL); + initModelSpinner(binding.modelSpinner, models, preferencesManager.getAutopilotModel()); + initServerSpinner(binding.serverSpinner); setAnalyserResolution(Enums.Preview.HD.getValue()); - binding.modelSpinner.setOnItemSelectedListener( - new AdapterView.OnItemSelectedListener() { - @Override - public void onItemSelected(AdapterView parent, View view, int position, long id) { - String selected = parent.getItemAtPosition(position).toString(); - try { - masterList.stream() - .filter(f -> f.name.contains(selected)) - .findFirst() - .ifPresent(value -> setModel(value)); - - } catch (IllegalArgumentException e) { - e.printStackTrace(); - } - } - - @Override - public void onNothingSelected(AdapterView parent) {} - }); binding.deviceSpinner.setOnItemSelectedListener( new AdapterView.OnItemSelectedListener() { @Override @@ -296,8 +256,6 @@ private void recreateNetwork(Model model, Network.Device device, int numThreads) @Override public synchronized void onResume() { - serverCommunication = new ServerCommunication(requireContext(), this); - serverCommunication.start(); handlerThread = new HandlerThread("inference"); handlerThread.start(); handler = new Handler(handlerThread.getLooper()); @@ -314,7 +272,6 @@ public synchronized void onPause() { } catch (final InterruptedException e) { e.printStackTrace(); } - serverCommunication.stop(); super.onPause(); } @@ -452,51 +409,11 @@ public void onConnectionEstablished(String ipAddress) { requireActivity().runOnUiThread(() -> binding.ipAddress.setText(ipAddress)); } - @Override - public void onAddModel(String model) { - Model item = - new Model( - masterList.size() + 1, - Model.CLASS.AUTOPILOT_F, - Model.TYPE.AUTOPILOT, - model, - Model.PATH_TYPE.FILE, - requireActivity().getFilesDir() + File.separator + model, - "256x96"); - - if (modelAdapter != null && modelAdapter.getPosition(model) == -1) { - modelAdapter.add(model); - masterList.add(item); - FileUtils.updateModelConfig(requireActivity(), masterList); - } else { - if (model.equals(binding.modelSpinner.getSelectedItem())) { - setModel(item); - } - } - Toast.makeText( - requireContext().getApplicationContext(), - "AutopilotModel added: " + model, - Toast.LENGTH_SHORT) - .show(); - } - - @Override - public void onRemoveModel(String model) { - if (modelAdapter != null && modelAdapter.getPosition(model) != -1) { - modelAdapter.remove(model); - } - Toast.makeText( - requireContext().getApplicationContext(), - "AutopilotModel removed: " + model, - Toast.LENGTH_SHORT) - .show(); - } - protected Model getModel() { return model; } - private void setModel(Model model) { + protected void setModel(Model model) { if (this.model != model) { Timber.d("Updating model: %s", model); this.model = model; diff --git a/android/app/src/main/java/org/openbot/common/ControlsFragment.java b/android/app/src/main/java/org/openbot/common/ControlsFragment.java index 27c7584d5..4cf6e81e0 100644 --- a/android/app/src/main/java/org/openbot/common/ControlsFragment.java +++ b/android/app/src/main/java/org/openbot/common/ControlsFragment.java @@ -7,13 +7,22 @@ import android.view.View; import android.view.animation.Animation; import android.view.animation.AnimationUtils; +import android.widget.AdapterView; +import android.widget.ArrayAdapter; +import android.widget.Spinner; +import android.widget.Toast; import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; import androidx.annotation.NonNull; import androidx.annotation.Nullable; import androidx.fragment.app.Fragment; import androidx.lifecycle.ViewModelProvider; +import java.io.File; import java.util.List; +import java.util.Set; +import java.util.function.Predicate; +import java.util.stream.Collectors; +import org.jetbrains.annotations.NotNull; import org.json.JSONObject; import org.openbot.R; import org.openbot.env.AudioPlayer; @@ -24,6 +33,8 @@ import org.openbot.env.SharedPreferencesManager; import org.openbot.env.Vehicle; import org.openbot.main.MainViewModel; +import org.openbot.server.ServerCommunication; +import org.openbot.server.ServerListener; import org.openbot.tflite.Model; import org.openbot.utils.ConnectionUtils; import org.openbot.utils.Constants; @@ -33,7 +44,9 @@ import org.openbot.utils.PermissionUtils; import timber.log.Timber; -public abstract class ControlsFragment extends Fragment { +public abstract class ControlsFragment extends Fragment implements ServerListener { + private static final String NO_SERVER = "No server"; + protected MainViewModel mViewModel; protected Vehicle vehicle; protected Animation startAnimation; @@ -46,6 +59,12 @@ public abstract class ControlsFragment extends Fragment { protected final String voice = "matthew"; protected List masterList; + private ServerCommunication serverCommunication; + private ArrayAdapter modelAdapter; + private ArrayAdapter serverAdapter; + private Spinner modelSpinner; + private Spinner serverSpinner; + @Override public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceState) { super.onViewCreated(view, savedInstanceState); @@ -58,6 +77,7 @@ public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceStat preferencesManager = new SharedPreferencesManager(requireContext()); audioPlayer = new AudioPlayer(requireContext()); masterList = FileUtils.loadConfigJSONFromAsset(requireActivity()); + serverCommunication = new ServerCommunication(requireContext(), this); requireActivity() .getSupportFragmentManager() @@ -238,8 +258,17 @@ private void toggleIndicatorEvent(int value) { } }); + @NotNull + protected List getModelNames(Predicate filter) { + return masterList.stream() + .filter(filter) + .map(f -> FileUtils.nameWithoutExtension(f.name)) + .collect(Collectors.toList()); + } + @Override public void onResume() { + serverCommunication.start(); super.onResume(); } @@ -254,6 +283,7 @@ public void onDestroy() { @Override public synchronized void onPause() { Timber.d("onPause"); + serverCommunication.stop(); vehicle.setControl(0, 0); super.onPause(); } @@ -264,6 +294,126 @@ public void onStop() { super.onStop(); } + protected void initModelSpinner(Spinner spinner, List models, String selected) { + modelAdapter = new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); + modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); + modelSpinner = spinner; + modelSpinner.setAdapter(modelAdapter); + if (!selected.isEmpty()) + modelSpinner.setSelection( + Math.max(0, modelAdapter.getPosition(FileUtils.nameWithoutExtension(selected))) + ); + modelSpinner.setOnItemSelectedListener( + new AdapterView.OnItemSelectedListener() { + @Override + public void onItemSelected(AdapterView parent, View view, int position, long id) { + String selected = parent.getItemAtPosition(position).toString(); + try { + masterList.stream() + .filter(f -> f.name.contains(selected)) + .findFirst() + .ifPresent(value -> setModel(value)); + + } catch (IllegalArgumentException e) { + e.printStackTrace(); + } + } + + @Override + public void onNothingSelected(AdapterView parent) { + } + }); + } + + protected void initServerSpinner(Spinner spinner) { + serverAdapter = new ArrayAdapter<>(requireContext(), R.layout.spinner_item); + serverAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); + serverSpinner = spinner; + serverSpinner.setAdapter(serverAdapter); + serverSpinner.setOnItemSelectedListener( + new AdapterView.OnItemSelectedListener() { + @Override + public void onItemSelected(AdapterView parent, View view, int position, long id) { + String selected = parent.getItemAtPosition(position).toString(); + if (selected.equals(NO_SERVER)) { + serverCommunication.disconnect(); + if (serverAdapter.getPosition(preferencesManager.getServer()) > -1) { + preferencesManager.setServer(selected); + } + } else { + serverCommunication.connect(selected); + preferencesManager.setServer(selected); + } + } + + @Override + public void onNothingSelected(AdapterView parent) { + serverCommunication.disconnect(); + } + }); + onServerListChange(serverCommunication.getServers()); + } + + @Override + public void onServerListChange(Set servers) { + if (serverAdapter == null) { + return; + } + requireActivity().runOnUiThread(() -> { + serverAdapter.clear(); + serverAdapter.add(NO_SERVER); + serverAdapter.addAll(servers); + if (!preferencesManager.getServer().isEmpty()) { + serverSpinner.setSelection(Math.max(0, serverAdapter.getPosition(preferencesManager.getServer()))); + } + }); + } + + @Override + public void onAddModel(String model) { + Model item = + new Model( + masterList.size() + 1, + Model.CLASS.AUTOPILOT_F, + Model.TYPE.AUTOPILOT, + model, + Model.PATH_TYPE.FILE, + requireActivity().getFilesDir() + File.separator + model, + "256x96"); + + if (modelAdapter != null && modelAdapter.getPosition(model) == -1) { + modelAdapter.add(model); + masterList.add(item); + FileUtils.updateModelConfig(requireActivity(), masterList); + } else { + if (model.equals(modelSpinner.getSelectedItem())) { + setModel(item); + } + } + Toast.makeText( + requireContext().getApplicationContext(), + "AutopilotModel added: " + model, + Toast.LENGTH_SHORT) + .show(); + } + + @Override + public void onRemoveModel(String model) { + if (modelAdapter != null && modelAdapter.getPosition(model) != -1) { + modelAdapter.remove(model); + } + Toast.makeText( + requireContext().getApplicationContext(), + "AutopilotModel removed: " + model, + Toast.LENGTH_SHORT) + .show(); + } + + @Override + public void onConnectionEstablished(String ipAddress) {} + + protected void setModel(Model model) {} + protected abstract void processControllerKeyData(String command); protected abstract void processUSBData(String data); diff --git a/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java b/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java index a5fa724e1..5d9dc5180 100644 --- a/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java +++ b/android/app/src/main/java/org/openbot/env/SharedPreferencesManager.java @@ -22,6 +22,7 @@ public class SharedPreferencesManager { private static final String DEFAULT_MODEL = "DEFAULT_MODEL_NAME"; private static final String OBJECT_NAV_MODEL = "OBJECT_NAV_MODEL_NAME"; private static final String AUTOPILOT_MODEL = "AUTOPILOT_MODEL_NAME"; + private static final String SERVER_NAME = "SERVER_NAME"; private static final String OBJECT_TYPE = "OBJECT_TYPE"; private static final String DEFAULT_OBJECT_TYPE = "person"; @@ -107,6 +108,14 @@ public String getAutopilotModel() { return preferences.getString(AUTOPILOT_MODEL, ""); } + public void setServer(String server) { + preferences.edit().putString(SERVER_NAME, server).apply(); + } + + public String getServer() { + return preferences.getString(SERVER_NAME, ""); + } + public void setObjectType(String model) { preferences.edit().putString(OBJECT_TYPE, model).apply(); } diff --git a/android/app/src/main/java/org/openbot/logging/LoggerFragment.java b/android/app/src/main/java/org/openbot/logging/LoggerFragment.java index f788a0a36..1ede08640 100644 --- a/android/app/src/main/java/org/openbot/logging/LoggerFragment.java +++ b/android/app/src/main/java/org/openbot/logging/LoggerFragment.java @@ -20,7 +20,6 @@ import android.view.View; import android.view.ViewGroup; import android.widget.AdapterView; -import android.widget.ArrayAdapter; import androidx.activity.result.ActivityResultLauncher; import androidx.activity.result.contract.ActivityResultContracts; import androidx.annotation.NonNull; @@ -34,7 +33,6 @@ import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.jetbrains.annotations.NotNull; import org.openbot.R; import org.openbot.common.CameraFragment; @@ -113,30 +111,10 @@ public void onViewCreated(@NonNull View view, @Nullable Bundle savedInstanceStat binding.cameraToggle.setOnClickListener(v -> toggleCamera()); - List models = - masterList.stream() - .filter(f -> f.pathType != Model.PATH_TYPE.URL) - .map(f -> org.openbot.utils.FileUtils.nameWithoutExtension(f.name)) - .collect(Collectors.toList()); - - ArrayAdapter modelAdapter = - new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); - modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); - binding.modelSpinner.setAdapter(modelAdapter); - binding.modelSpinner.setOnItemSelectedListener( - new AdapterView.OnItemSelectedListener() { - @Override - public void onItemSelected(AdapterView parent, View view, int position, long id) { - String selected = parent.getItemAtPosition(position).toString(); - masterList.stream() - .filter(f -> f.name.contains(selected)) - .findFirst() - .ifPresent(f -> updateCropImageInfo(f)); - } + List models = getModelNames(f -> f.pathType != Model.PATH_TYPE.URL); + initModelSpinner(binding.modelSpinner, models, ""); + initServerSpinner(binding.serverSpinner); - @Override - public void onNothingSelected(AdapterView parent) {} - }); binding.resolutionSpinner.setOnItemSelectedListener( new AdapterView.OnItemSelectedListener() { @Override @@ -173,7 +151,8 @@ public void onNothingSelected(AdapterView parent) {} }); } - private void updateCropImageInfo(Model selected) { + @Override + protected void setModel(Model selected) { frameToCropTransform = null; binding.cropInfo.setText( String.format( @@ -553,10 +532,4 @@ protected void processFrame(Bitmap bitmap, ImageProxy image) { public void onConnectionEstablished(String ipAddress) { requireActivity().runOnUiThread(() -> binding.ipAddress.setText(ipAddress)); } - - @Override - public void onAddModel(String model) {} - - @Override - public void onRemoveModel(String model) {} } diff --git a/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java b/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java index c93bbf984..197dd40a3 100644 --- a/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java +++ b/android/app/src/main/java/org/openbot/objectNav/ObjectNavFragment.java @@ -28,7 +28,6 @@ import java.util.List; import java.util.Locale; import java.util.concurrent.TimeUnit; -import java.util.stream.Collectors; import org.jetbrains.annotations.NotNull; import org.openbot.R; import org.openbot.common.CameraFragment; @@ -42,7 +41,6 @@ import org.openbot.tracking.MultiBoxTracker; import org.openbot.utils.Constants; import org.openbot.utils.Enums; -import org.openbot.utils.FileUtils; import org.openbot.utils.PermissionUtils; import timber.log.Timber; @@ -132,41 +130,10 @@ public void onNothingSelected(AdapterView parent) {} binding.cameraToggle.setOnClickListener(v -> toggleCamera()); - List models = - masterList.stream() - .filter(f -> f.type.equals(Model.TYPE.DETECTOR) && f.pathType != Model.PATH_TYPE.URL) - .map(f -> FileUtils.nameWithoutExtension(f.name)) - .collect(Collectors.toList()); - ArrayAdapter modelAdapter = - new ArrayAdapter<>(requireContext(), R.layout.spinner_item, models); - - modelAdapter.setDropDownViewResource(android.R.layout.simple_dropdown_item_1line); - binding.modelSpinner.setAdapter(modelAdapter); - if (!preferencesManager.getObjectNavModel().isEmpty()) - binding.modelSpinner.setSelection( - Math.max( - 0, - modelAdapter.getPosition( - FileUtils.nameWithoutExtension(preferencesManager.getObjectNavModel())))); + List models = getModelNames(f -> f.type.equals(Model.TYPE.DETECTOR) && f.pathType != Model.PATH_TYPE.URL); + initModelSpinner(binding.modelSpinner, models, preferencesManager.getObjectNavModel()); setAnalyserResolution(Enums.Preview.HD.getValue()); - binding.modelSpinner.setOnItemSelectedListener( - new AdapterView.OnItemSelectedListener() { - @Override - public void onItemSelected(AdapterView parent, View view, int position, long id) { - String selected = parent.getItemAtPosition(position).toString(); - try { - masterList.stream() - .filter(f -> f.name.contains(selected)) - .findFirst() - .ifPresent(value -> setModel(value)); - } catch (IllegalArgumentException e) { - } - } - - @Override - public void onNothingSelected(AdapterView parent) {} - }); binding.deviceSpinner.setOnItemSelectedListener( new AdapterView.OnItemSelectedListener() { @Override @@ -508,7 +475,8 @@ protected Model getModel() { return model; } - private void setModel(Model model) { + @Override + protected void setModel(Model model) { if (this.model != model) { Timber.d("Updating model: %s", model); this.model = model; diff --git a/android/app/src/main/java/org/openbot/original/CameraActivity.java b/android/app/src/main/java/org/openbot/original/CameraActivity.java index 2dcd6b065..70e5df2cf 100755 --- a/android/app/src/main/java/org/openbot/original/CameraActivity.java +++ b/android/app/src/main/java/org/openbot/original/CameraActivity.java @@ -77,6 +77,7 @@ import java.util.List; import java.util.Locale; import java.util.Objects; +import java.util.Set; import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import org.json.JSONObject; @@ -612,6 +613,9 @@ public void onRemoveModel(String model) { Toast.makeText(context, "AutopilotModel removed: " + model, Toast.LENGTH_SHORT).show(); } + @Override + public void onServerListChange(Set servers) {} + @Override public void onConnectionEstablished(String ipAddress) {} diff --git a/android/app/src/main/java/org/openbot/server/ServerCommunication.java b/android/app/src/main/java/org/openbot/server/ServerCommunication.java index 83966df33..a1d0f5479 100644 --- a/android/app/src/main/java/org/openbot/server/ServerCommunication.java +++ b/android/app/src/main/java/org/openbot/server/ServerCommunication.java @@ -11,7 +11,10 @@ import cz.msebera.android.httpclient.Header; import java.io.File; import java.io.FileNotFoundException; +import java.util.HashMap; import java.util.HashSet; +import java.util.Map; +import java.util.Set; import java.util.Timer; import java.util.TimerTask; import org.json.JSONArray; @@ -25,6 +28,7 @@ public class ServerCommunication { private final AsyncHttpClient client; private final Context context; private final NsdService nsdService; + private final Map servers = new HashMap<>(); private final NsdManager.ResolveListener resolveListener = new NsdManager.ResolveListener() { @Override @@ -35,13 +39,12 @@ public void onResolveFailed(NsdServiceInfo serviceInfo, int errorCode) { @Override public void onServiceResolved(NsdServiceInfo serviceInfo) { - nsdService.stop(); - serverUrl = - "http://" + serviceInfo.getHost().getHostAddress() + ":" + serviceInfo.getPort(); - Timber.d("Resolved address: %s", serverUrl); - - client.get(context, serverUrl + "/test", testResponseHandler); - serverListener.onConnectionEstablished(serverUrl); + servers.put(serviceInfo.getServiceName(), serviceInfo); + try { + serverListener.onServerListChange(servers.keySet()); + } catch (Exception e) { + Timber.w(e); + } } }; private final JsonHttpResponseHandler testResponseHandler = @@ -163,6 +166,26 @@ public void run() { 10000); } + public void connect(String server) { + NsdServiceInfo serviceInfo = servers.get(server); + if (serviceInfo == null) { + Timber.e("Server not found: %s", server); + return; + } + String ipAddress = serviceInfo.getHost().getHostAddress(); + serverUrl = "http://" + ipAddress + ":" + serviceInfo.getPort(); + Timber.d("Resolved address: %s", serverUrl); + + client.get(context, serverUrl + "/test", testResponseHandler); + serverListener.onConnectionEstablished(ipAddress); + } + + public void disconnect() { + client.cancelRequests(context, true); + serverUrl = null; + serverListener.onConnectionEstablished(context.getString(R.string.ip_placeholder)); + } + public void upload(File file) { if (serverUrl == null) { return; @@ -202,9 +225,14 @@ public void uploadAll() { public void stop() { client.cancelRequests(context, true); + nsdService.stop(); timer.cancel(); } + public Set getServers() { + return servers.keySet(); + } + static class UploadResponseHandler extends JsonHttpResponseHandler { private final File file; diff --git a/android/app/src/main/java/org/openbot/server/ServerListener.java b/android/app/src/main/java/org/openbot/server/ServerListener.java index d2aa466d1..05bbe6106 100644 --- a/android/app/src/main/java/org/openbot/server/ServerListener.java +++ b/android/app/src/main/java/org/openbot/server/ServerListener.java @@ -1,9 +1,13 @@ package org.openbot.server; +import java.util.Set; + public interface ServerListener { void onAddModel(String model); void onRemoveModel(String model); void onConnectionEstablished(String ipAddress); + + void onServerListChange(Set servers); } diff --git a/android/app/src/main/res/layout-land/fragment_autopilot.xml b/android/app/src/main/res/layout-land/fragment_autopilot.xml index e91ea46aa..ddbef5928 100644 --- a/android/app/src/main/res/layout-land/fragment_autopilot.xml +++ b/android/app/src/main/res/layout-land/fragment_autopilot.xml @@ -71,25 +71,33 @@ android:layout_height="wrap_content"> + android:layout_weight="1.2" + android:gravity="center_vertical|start" + android:paddingHorizontal="8dp" + android:text="@string/ip_placeholder" + android:textColor="@android:color/black" /> + + + diff --git a/android/app/src/main/res/layout-land/fragment_logger.xml b/android/app/src/main/res/layout-land/fragment_logger.xml index 74a206171..ea84ceb14 100644 --- a/android/app/src/main/res/layout-land/fragment_logger.xml +++ b/android/app/src/main/res/layout-land/fragment_logger.xml @@ -32,8 +32,8 @@ android:id="@+id/usbToggle" android:layout_width="wrap_content" android:layout_height="wrap_content" - android:button="@drawable/usb_toggle" android:layout_marginEnd="16dp" + android:button="@drawable/usb_toggle" app:layout_constraintBottom_toBottomOf="@+id/camera_toggle" app:layout_constraintEnd_toStartOf="@+id/camera_toggle" app:layout_constraintTop_toTopOf="@+id/camera_toggle" /> @@ -90,7 +90,7 @@ android:layout_width="0dp" android:layout_height="0dp" android:layout_marginStart="16dp" - android:layout_marginEnd="16dp" + android:layout_marginEnd="8dp" android:entries="@array/preview_resolutions" android:gravity="center" android:prompt="@string/preview_resolution" @@ -103,14 +103,27 @@ android:id="@+id/model_spinner" android:layout_width="0dp" android:layout_height="0dp" - android:layout_marginStart="16dp" - android:layout_marginEnd="16dp" - tools:entries="@array/models" + android:layout_marginEnd="8dp" android:prompt="@string/model" app:layout_constraintBottom_toBottomOf="@+id/crop_info" app:layout_constraintEnd_toEndOf="parent" - app:layout_constraintStart_toEndOf="@+id/crop_info" - app:layout_constraintTop_toTopOf="@+id/crop_info" /> + app:layout_constraintStart_toStartOf="@+id/resolution_spinner" + app:layout_constraintTop_toTopOf="@+id/crop_info" + tools:entries="@array/models" /> + + @@ -135,6 +148,29 @@ app:layout_constraintStart_toEndOf="@+id/analyseText" app:layout_constraintTop_toTopOf="@+id/analyseText" /> + + + + +