Skip to content

Commit

Permalink
model update
Browse files Browse the repository at this point in the history
  • Loading branch information
gtanczyk committed Oct 14, 2024
1 parent 6a55c31 commit bb53fbe
Show file tree
Hide file tree
Showing 17 changed files with 1,161 additions and 9,462 deletions.
2 changes: 1 addition & 1 deletion games/masterplan/src/public/tfmodel/model.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_1","layers":[{"class_name":"LSTM","config":{"name":"lstm_LSTM1","trainable":true,"batch_input_shape":[null,15,270],"dtype":"float32","units":30,"activation":"relu","recurrent_activation":"hard_sigmoid","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"recurrent_initializer":{"class_name":"Orthogonal","config":{"gain":1,"seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"unit_forget_bias":null,"kernel_regularizer":null,"recurrent_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"recurrent_constraint":null,"bias_constraint":null,"dropout":0,"recurrent_dropout":0,"implementation":null,"return_sequences":true,"return_state":false,"go_backwards":false,"stateful":false,"unroll":false}},{"class_name":"TimeDistributed","config":{"layer":{"class_name":"Dense","config":{"units":270,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense1","trainable":true}},"name":"time_distributed_TimeDistributed1","trainable":true}}]},"keras_version":"tfjs-layers 4.21.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"lstm_LSTM1/kernel","shape":[270,120],"dtype":"float32"},{"name":"lstm_LSTM1/recurrent_kernel","shape":[30,120],"dtype":"float32"},{"name":"lstm_LSTM1/bias","shape":[120],"dtype":"float32"},{"name":"time_distributed_TimeDistributed1/kernel","shape":[30,270],"dtype":"float32"},{"name":"time_distributed_TimeDistributed1/bias","shape":[270],"dtype":"float32"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v4.21.0","convertedBy":null}
{"modelTopology":{"class_name":"Sequential","config":{"name":"sequential_1","layers":[{"class_name":"LSTM","config":{"name":"lstm_LSTM1","trainable":true,"batch_input_shape":[null,15,150],"dtype":"float32","units":30,"activation":"relu","recurrent_activation":"hard_sigmoid","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"recurrent_initializer":{"class_name":"Orthogonal","config":{"gain":1,"seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"unit_forget_bias":null,"kernel_regularizer":null,"recurrent_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"recurrent_constraint":null,"bias_constraint":null,"dropout":0,"recurrent_dropout":0,"implementation":null,"return_sequences":true,"return_state":false,"go_backwards":false,"stateful":false,"unroll":false}},{"class_name":"TimeDistributed","config":{"layer":{"class_name":"Dense","config":{"units":270,"activation":"linear","use_bias":true,"kernel_initializer":{"class_name":"VarianceScaling","config":{"scale":1,"mode":"fan_avg","distribution":"normal","seed":null}},"bias_initializer":{"class_name":"Zeros","config":{}},"kernel_regularizer":null,"bias_regularizer":null,"activity_regularizer":null,"kernel_constraint":null,"bias_constraint":null,"name":"dense_Dense1","trainable":true}},"name":"time_distributed_TimeDistributed1","trainable":true}}]},"keras_version":"tfjs-layers 4.21.0","backend":"tensor_flow.js"},"weightsManifest":[{"paths":["weights.bin"],"weights":[{"name":"lstm_LSTM1/kernel","shape":[150,120],"dtype":"float32"},{"name":"lstm_LSTM1/recurrent_kernel","shape":[30,120],"dtype":"float32"},{"name":"lstm_LSTM1/bias","shape":[120],"dtype":"float32"},{"name":"time_distributed_TimeDistributed1/kernel","shape":[30,270],"dtype":"float32"},{"name":"time_distributed_TimeDistributed1/bias","shape":[270],"dtype":"float32"}]}],"format":"layers-model","generatedBy":"TensorFlow.js tfjs-layers v4.21.0","convertedBy":null}
Binary file modified games/masterplan/src/public/tfmodel/weights.bin
Binary file not shown.
5 changes: 3 additions & 2 deletions games/masterplan/src/screens/battle/game/game-render.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,17 @@ import { getCanvas } from '../util/canvas';
import { LAYER_DEFAULT } from '../consts';
import { GameWorld } from './game-world';
import { RenderQueue } from './game-render-queue';
import { GameWorldRender } from './game-world-render';

export function renderGame(world: GameWorld) {
export function renderGame(world: GameWorld, worldRender: GameWorldRender) {
const canvas = getCanvas(LAYER_DEFAULT);
const renderQueue = new RenderQueue();

// clear
canvas.clear();

// render terrain
world.terrain.render(canvas);
canvas.drawImage(worldRender.terrainCanvas, 0, 0);

// set camera
canvas.save().translate(canvas.getWidth() / 2, canvas.getHeight() / 2);
Expand Down
15 changes: 15 additions & 0 deletions games/masterplan/src/screens/battle/game/game-world-render.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import { GameWorld } from './game-world';
import { createTerrainTexture } from './terrain/terrain-renderer';

export class GameWorldRender {
terrainCanvas: HTMLCanvasElement;

constructor(gameWorld: GameWorld) {
this.terrainCanvas = createTerrainTexture(
gameWorld.terrain.width,
gameWorld.terrain.height,
gameWorld.terrain.heightMap,
gameWorld.terrain.tileSize,
);
}
}
14 changes: 3 additions & 11 deletions games/masterplan/src/screens/battle/game/terrain/terrain.ts
Original file line number Diff line number Diff line change
@@ -1,14 +1,11 @@
import { EDGE_RADIUS } from '../../consts';
import { Canvas } from '../../util/canvas';
import { Vec } from '../../util/vmath';
import { TerrainData } from './terrain-generator';
import { createTerrainTexture } from './terrain-renderer';

export class Terrain {
tileSize = 32;
offsetX = (EDGE_RADIUS * 1.5) / this.tileSize;
offsetY = EDGE_RADIUS / this.tileSize;
terrainCanvas: HTMLCanvasElement;
heightMap: number[][];
width: number;
height: number;
Expand All @@ -17,22 +14,17 @@ export class Terrain {
this.width = terrainData.width;
this.height = terrainData.height;
this.heightMap = terrainData.heightMap;
this.terrainCanvas = createTerrainTexture(this.width, this.height, this.heightMap, this.tileSize);
}

render(ctx: Canvas) {
ctx.drawImage(this.terrainCanvas, 0, 0);
}

getHeightAt(pos: Vec): number {
// Calculate tile indices and positions within the tile
const x = pos[0] / this.tileSize + this.offsetX;
const y = pos[1] / this.tileSize + this.offsetY;

const x0 = Math.floor(x);
const x0 = Math.min(Math.max(Math.floor(x), 0), this.heightMap.length - 1);
const x1 = Math.min(x0 + 1, this.heightMap[0].length - 1);
const y0 = Math.floor(y);
const y1 = Math.min(y0 + 1, this.heightMap.length - 1);
const y0 = Math.min(Math.max(Math.floor(y), 0), this.heightMap.length - 1);
const y1 = Math.max(Math.min(y0 + 1, this.heightMap.length - 1), 0);

// Get heights at the four corners
const h00 = this.heightMap[y0][x0];
Expand Down
61 changes: 61 additions & 0 deletions games/masterplan/src/screens/battle/model/convert-input.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import { Unit, UnitType } from '../../designer/designer-types';
import { MAX_COL, MAX_ROW } from '../consts';
import { TerrainData } from '../game/terrain/terrain-generator';
import {
INPUT_CELL_INDEX_MAP,
INPUT_CELL_UNIT_INDEX_MAP,
INPUT_COLS,
INPUT_ROWS,
ModelInput,
ModelInputCell,
} from './types';

export function unitsToModelInput(units: Unit[], terrainData: TerrainData): ModelInput {
const cols = INPUT_COLS;
const rows = INPUT_ROWS;
const data = Array.from({ length: rows }, () =>
Array.from({ length: cols }, () => Array(Object.keys(INPUT_CELL_INDEX_MAP).length).fill(0) as ModelInputCell),
);

let maxUnitCount = 0;

for (const u of units) {
for (let i = 0; i < u.sizeRow; i++) {
for (let j = 0; j < u.sizeCol; j++) {
const gameRow = MAX_ROW / 2 + u.row + i;
const gameCol = MAX_COL / 2 + u.col + j;
const inputRow = Math.min(Math.floor((gameRow * rows) / MAX_ROW), INPUT_ROWS - 1);
const inputCol = Math.min(Math.floor((gameCol * cols) / MAX_COL), INPUT_COLS - 1);
if (data[inputRow] && data[inputRow][inputCol]) {
data[inputRow][inputCol][INPUT_CELL_INDEX_MAP[u.type]] =
1 + (data[inputRow][inputCol][INPUT_CELL_INDEX_MAP[u.type]] || 0);
maxUnitCount = Math.max(maxUnitCount, data[inputRow][inputCol][INPUT_CELL_INDEX_MAP[u.type]]);
}
}
}
}

const maxHeight = Math.max(...terrainData.heightMap.flat());

// Add terrain height to the input
for (let row = 0; row < rows; row++) {
for (let col = 0; col < cols; col++) {
const gameRow = Math.floor((row * MAX_ROW) / rows);
const gameCol = Math.floor((col * MAX_COL) / cols);
const height = terrainData.heightMap[gameRow][gameCol];
data[row][col][INPUT_CELL_INDEX_MAP.terrainHeight] = height / maxHeight;
}
}

// Normalize unit counts using maxUnitCount
for (const row of data) {
for (const cell of row) {
for (const key in INPUT_CELL_UNIT_INDEX_MAP) {
cell[INPUT_CELL_UNIT_INDEX_MAP[key as UnitType]] =
cell[INPUT_CELL_UNIT_INDEX_MAP[key as UnitType]] / maxUnitCount;
}
}
}

return { cols, rows, data };
}
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
import { rotateUnits, Unit, UnitType } from '../../designer/designer-types';
import { MAX_COL, MAX_ROW } from '../consts';
import { CELL_INDEX_MAP, COMMAND_INDEX_MAP, INPUT_COLS, INPUT_ROWS, ModelInput, ModelInputCell } from './types';
import { MAX_ROW, MAX_COL } from '../consts';
import {
ModelOutput,
INPUT_COLS,
INPUT_ROWS,
OUTPUT_CELL_INDEX_MAP,
ModelOutputCell,
OUTPUT_CELL_UNIT_INDEX_MAP,
OUTPUT_CELL_COMMAND_INDEX_MAP,
} from './types';

export function unitsToModelInput(units: Unit[]): ModelInput {
export function unitsToModelOutput(units: Unit[]): ModelOutput {
const cols = INPUT_COLS;
const rows = INPUT_ROWS;
const data = Array.from({ length: rows }, () =>
Array.from(
{ length: cols },
() => Array(Object.keys(CELL_INDEX_MAP).length + Object.keys(COMMAND_INDEX_MAP).length).fill(0) as ModelInputCell,
),
Array.from({ length: cols }, () => Array(Object.keys(OUTPUT_CELL_INDEX_MAP).length).fill(0) as ModelOutputCell),
);

let maxUnitCount = 0;
Expand All @@ -22,11 +27,11 @@ export function unitsToModelInput(units: Unit[]): ModelInput {
const inputRow = Math.min(Math.floor((gameRow * rows) / MAX_ROW), INPUT_ROWS - 1);
const inputCol = Math.min(Math.floor((gameCol * cols) / MAX_COL), INPUT_COLS - 1);
if (data[inputRow] && data[inputRow][inputCol]) {
data[inputRow][inputCol][CELL_INDEX_MAP[u.type]] =
1 + (data[inputRow][inputCol][CELL_INDEX_MAP[u.type]] || 0);
maxUnitCount = Math.max(maxUnitCount, data[inputRow][inputCol][CELL_INDEX_MAP[u.type]]);
if (u.command && COMMAND_INDEX_MAP[u.command]) {
data[inputRow][inputCol][COMMAND_INDEX_MAP[u.command]] = 1;
data[inputRow][inputCol][OUTPUT_CELL_INDEX_MAP[u.type]] =
1 + (data[inputRow][inputCol][OUTPUT_CELL_INDEX_MAP[u.type]] || 0);
maxUnitCount = Math.max(maxUnitCount, data[inputRow][inputCol][OUTPUT_CELL_INDEX_MAP[u.type]]);
if (u.command && OUTPUT_CELL_INDEX_MAP[u.command]) {
data[inputRow][inputCol][OUTPUT_CELL_INDEX_MAP[u.command]] = 1;
}
}
}
Expand All @@ -36,36 +41,36 @@ export function unitsToModelInput(units: Unit[]): ModelInput {
// Normalize unit counts using maxUnitCount
for (const row of data) {
for (const cell of row) {
for (const key in CELL_INDEX_MAP) {
cell[CELL_INDEX_MAP[key as UnitType]] = cell[CELL_INDEX_MAP[key as UnitType]] / maxUnitCount;
for (const key in OUTPUT_CELL_UNIT_INDEX_MAP) {
cell[OUTPUT_CELL_UNIT_INDEX_MAP[key as UnitType]] =
cell[OUTPUT_CELL_UNIT_INDEX_MAP[key as UnitType]] / maxUnitCount;
}
}
}

return { cols, rows, data };
}

export function modelInputToUnits(input: ModelInput): Unit[] {
export function modelOutputToUnits(output: ModelOutput): Unit[] {
const units: Unit[] = [];

const inputMap: Record<string, { type: UnitType; command?: Unit['command'] }> = {};

const unitValues = input.data
.map((row) => row.map((col) => Object.values(CELL_INDEX_MAP).map((idx) => col[idx])))
const unitValues = output.data
.map((row) => row.map((col) => Object.values(OUTPUT_CELL_UNIT_INDEX_MAP).map((idx) => col[idx])))
.flat(2);
const min = Math.min(...unitValues);
const max = Math.max(...unitValues);
const threshold = min + (max - min) * 0.5;

for (let i = 0; i < input.rows; i++) {
for (let j = 0; j < input.cols; j++) {
const type = Object.keys(CELL_INDEX_MAP)
.map((k) => [k, input.data[i][j][CELL_INDEX_MAP[k as UnitType]]] as const)
for (let i = 0; i < output.rows; i++) {
for (let j = 0; j < output.cols; j++) {
const type = Object.keys(OUTPUT_CELL_UNIT_INDEX_MAP)
.map((k) => [k, output.data[i][j][OUTPUT_CELL_UNIT_INDEX_MAP[k as UnitType]]] as const)
.sort((a, b) => b[1] - a[1])
.filter((cell) => cell[1] > threshold)?.[0]?.[0] as UnitType | undefined;

const command = Object.keys(COMMAND_INDEX_MAP)
.map((k) => [k, input.data[i][j][COMMAND_INDEX_MAP[k as Unit['command']]]] as const)
const command = Object.keys(OUTPUT_CELL_COMMAND_INDEX_MAP)
.map((k) => [k, output.data[i][j][OUTPUT_CELL_COMMAND_INDEX_MAP[k as Unit['command']]]] as const)
.sort((a, b) => b[1] - a[1])[0]?.[0] as Unit['command'] | undefined;

if (type) {
Expand All @@ -76,8 +81,8 @@ export function modelInputToUnits(input: ModelInput): Unit[] {

for (let x = 0; x < MAX_COL; x++) {
for (let y = 0; y < MAX_ROW; y++) {
const inputRow = Math.min(Math.floor((y * input.rows) / MAX_ROW), INPUT_ROWS - 1);
const inputCol = Math.min(Math.floor((x * input.cols) / MAX_COL), INPUT_COLS - 1);
const inputRow = Math.min(Math.floor((y * output.rows) / MAX_ROW), INPUT_ROWS - 1);
const inputCol = Math.min(Math.floor((x * output.cols) / MAX_COL), INPUT_COLS - 1);
const unitData = inputMap[inputRow + ',' + inputCol];
if (unitData) {
units.push({
Expand Down
10 changes: 6 additions & 4 deletions games/masterplan/src/screens/battle/model/predict-browser.ts
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
import { predict } from './tf-browser';
import { Unit } from '../../designer/designer-types';
import { modelInputToUnits, unitsToModelInput } from './convert';
import { unitsToModelInput } from './convert-input';
import { modelOutputToUnits } from './convert-output';
import { countSoldiers, trimUnits } from './units-trim';
import { consolidateUnits } from './units-consolidate';
import { TerrainData } from '../game/terrain/terrain-generator';

export async function predictCounterPlan(playerPlan: Unit[]): Promise<Unit[]> {
const modelInput = unitsToModelInput(playerPlan);
export async function predictCounterPlan(playerPlan: Unit[], terrainData: TerrainData): Promise<Unit[]> {
const modelInput = unitsToModelInput(playerPlan, terrainData);
const modelOutput = await predict(modelInput);
return trimUnits(consolidateUnits(modelInputToUnits(modelOutput)), countSoldiers(playerPlan));
return trimUnits(consolidateUnits(modelOutputToUnits(modelOutput)), countSoldiers(playerPlan));
}
Loading

0 comments on commit bb53fbe

Please sign in to comment.