Skip to content

Commit

Permalink
Nailed TF version to 2.13 as >= 2.14 has problems with the GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
DocGarbanzo committed Mar 10, 2024
1 parent 759da00 commit fead468
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 39 deletions.
9 changes: 6 additions & 3 deletions donkeycar/parts/keras_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,19 +999,22 @@ def square_plus_controller(in_tensor, size, l2, seq_len=None,


def create_name(has_lap_pct, imu_dim, mem_len, multi_input, seq_len, size):
name = None

if multi_input:
if imu_dim:
name = f'SquarePlusImu_{size}_{imu_dim}'
elif has_lap_pct:
name = f'SquarePlusMemLap_{size}_{mem_len}'
elif mem_len:
name = f'SquarePlusMem_{size}_{mem_len}'
assert seq_len is None, "SquarePlusMem doesn't work with LSTM"
else:
raise RuntimeError("Needs imu dim or mem length")
name = f'SquarePlusMem_{size}_{mem_len}'
else:
name = 'SquarePlus_' + size

if seq_len:
name += '_lstm_' + str(seq_len)

return name


68 changes: 42 additions & 26 deletions donkeycar/parts/tub_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,48 +28,64 @@ def __init__(self, tub: Tub, gyro_z_index: int = 1):
f' assuming use {"gym" if gyro_z_index==1 else "real"}')

def generate_laptimes_from_records(self, overwrite=False):

def new_session(session_id, lap_times, this_session_id, this_lap,
record):
if session_id is not None:
# copy results of current session
res[session_id] = copy(lap_times)
# reset lap times
lap_times.clear()

session_id = this_session_id
lap = this_lap
time_stamp_ms = record['_timestamp_ms']
dist = record['car/distance']
return session_id, lap, time_stamp_ms, dist, lap_times

session_id = None
lap = 0
dist = 0
time_stamp_ms = None
lap_times = []
res = {}
# self is iterable

for record in self.tub:
this_session_id = record.get('_session_id')
this_lap = record['car/lap']

if this_session_id != session_id:
# stepping into new session
if session_id:
# copy results of current session
res[session_id] = copy(lap_times)
# reset lap_times and lap
lap_times.clear()
lap = this_lap
session_id = this_session_id
time_stamp_ms = record['_timestamp_ms']
dist = record['car/distance']

if this_lap != lap:
assert this_lap > lap, f'Found smaller lap {this_lap} than ' \
f'previous lap {lap} in session {session_id}'
this_time_stamp_ms = record['_timestamp_ms']
lap_time = (this_time_stamp_ms - time_stamp_ms) / 1000
this_dist = record['car/distance']
lap_dist = this_dist - dist
lap_times.append(dict(lap=lap, time=lap_time, distance=lap_dist))
lap = this_lap
time_stamp_ms = this_time_stamp_ms
dist = this_dist
# add last session id
session_id, lap, time_stamp_ms, dist, lap_times = new_session(
session_id, lap_times, this_session_id, this_lap, record)
continue

if this_lap == lap:
continue

assert this_lap > lap, (f'Found smaller lap {this_lap} than previous'
f' lap {lap} in session {session_id}')

this_time_stamp_ms = record['_timestamp_ms']
lap_time = (this_time_stamp_ms - time_stamp_ms) / 1000
this_dist = record['car/distance']
lap_dist = this_dist - dist
lap_times.append(dict(lap=lap, time=lap_time, distance=lap_dist))

lap = this_lap
time_stamp_ms = this_time_stamp_ms
dist = this_dist

assert session_id is not None, "Session id should not be None"
res[session_id] = lap_times

for sess_id, lap_times in res.items():
meta_session_id_dict = self.tub.manifest.metadata.get(sess_id)
if not meta_session_id_dict:
self.tub.manifest.metadata[sess_id] = dict(laptimer=lap_times)
elif 'laptimer' in meta_session_id_dict and overwrite or \
'laptimer' not in meta_session_id_dict:
elif ('laptimer' in meta_session_id_dict and overwrite
or 'laptimer' not in meta_session_id_dict):
meta_session_id_dict['laptimer'] = lap_times

self.tub.manifest.write_metadata()
logger.info(f'Generated lap times {res}')

Expand Down
20 changes: 10 additions & 10 deletions donkeycar/pipeline/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,19 +21,19 @@ def __init__(self, cfg: Config) -> None:
self.entries = self.read()

def read(self) -> List[Dict]:
if os.path.exists(self.path):
try:
with open(self.path, "r") as read_file:
data = json.load(read_file)
logger.info(f'Found model database {self.path}')
return data
except Exception as e:
logger.error(f"Could not open database file because: {e}")
return []
else:
if not os.path.exists(self.path):
logger.warning(f'No model database found at {self.path}')
return []

try:
with open(self.path, "r") as read_file:
data = json.load(read_file)
logger.info(f'Found model database {self.path}')
return data
except Exception as e:
logger.error(f"Could not open database file because: {e}")
return []

def generate_model_name(self) -> Tuple[str, int]:
if self.entries:
df = self.to_df()
Expand Down

0 comments on commit fead468

Please sign in to comment.