Skip to content

Commit

Permalink
modified view_level logic to take target_size + added timelapse code
Browse files Browse the repository at this point in the history
  • Loading branch information
shyamsn97 committed Feb 20, 2023
1 parent 897769d commit 74d9ff8
Show file tree
Hide file tree
Showing 4 changed files with 356 additions and 148 deletions.
4 changes: 3 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,11 @@ Architecture | Example Prompt Generations
![alt text](static/architecture.png) | ![alt text](static/prompt-samples.png)


MarioGPT is a finetuned GPT2 model (specifically, [distilgpt2](https://huggingface.co/distilgpt2)), that is trained on a subset Super Mario Bros and Super Mario Bros: The Lost Levels levels, provided by [The Video Game Level Corpus](https://github.com/TheVGLC/TheVGLC). MarioGPT is able to generate levels, guided by a simple text prompt. This generation is not perfect, but we believe this is a great first step more controllable and diverse level / environment generation.
MarioGPT is a finetuned GPT2 model (specifically, [distilgpt2](https://huggingface.co/distilgpt2)), that is trained on a subset Super Mario Bros and Super Mario Bros: The Lost Levels levels, provided by [The Video Game Level Corpus](https://github.com/TheVGLC/TheVGLC). MarioGPT is able to generate levels, guided by a simple text prompt. This generation is not perfect, but we believe this is a great first step more controllable and diverse level / environment generation. Forward generation:


![alt text](static/timelapse_0.gif)

Requirements
----
- python3.8+
Expand Down
45 changes: 32 additions & 13 deletions mario_gpt/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import math
import os
from typing import List, Union

Expand Down Expand Up @@ -27,10 +28,13 @@ def join_list_of_list(str_lists):
def view_level(level_tokens, tokenizer, flatten=False):
if flatten:
return tokenizer.batch_decode(level_tokens.detach().cpu().squeeze())
str_list = [
s.replace("<mask>", "Y")
for s in tokenizer.batch_decode(level_tokens.detach().cpu().view(-1, 14))
]
str_list = tokenizer.decode(level_tokens.detach().cpu()).replace("<mask>", "Y")
str_list = [str_list[i : i + 14] for i in range(0, len(str_list), 14)]
for i in range(len(str_list)):
length = len(str_list[i])
diff = 14 - length
if diff > 0:
str_list[i] = str_list[i] + "Y" * diff
return join_list_of_list(np.array(characterize(str_list)).T)


Expand All @@ -42,29 +46,31 @@ def is_flying_enemy(array, row, col):
return below == "-"


def char_array_to_image(array, chars2pngs):
def char_array_to_image(array, chars2pngs, target_size=None):
"""
Convert a 16-by-16 array of integers into a PIL.Image object
param: array: a 16-by-16 array of integers
"""
image = Image.new("RGB", (array.shape[1] * 16, array.shape[0] * 16))
if target_size is None:
image = Image.new("RGB", (array.shape[1] * 16, array.shape[0] * 16))
else:
image = Image.new("RGB", (target_size[1] * 16, target_size[0] * 16))
for row in range(array.shape[0]):
for col, char in enumerate(array[row]):
value = chars2pngs["-"]
# if char == "E":
# if is_flying_enemy(array, row, col):
# char = "F"
if char in chars2pngs:
value = chars2pngs[char]
else:
print(f"REPLACING {value}", (col, row))

image.paste(value, (col * 16, row * 16))
return image


def convert_level_to_png(
level: Union[str, torch.Tensor], tokenizer=None, tiles_dir: str = None
level: Union[str, torch.Tensor],
tokenizer=None,
tiles_dir: str = None,
target_size=None,
):
if isinstance(level, torch.Tensor):
level = view_level(level, tokenizer)
Expand All @@ -83,15 +89,28 @@ def convert_level_to_png(
"[": Image.open(f"{tiles_dir}/smb-tube-lower-left.png"),
"]": Image.open(f"{tiles_dir}/smb-tube-lower-right.png"),
"x": Image.open(f"{tiles_dir}/smb-path.png"), # self-created
"Y": Image.open(f"{tiles_dir}/Y.png"), # self-created
"Y": Image.fromarray(
np.uint8(np.zeros((16, 16)))
), # black square, # self-created
"N": Image.open(f"{tiles_dir}/N.png"), # self-created
"B": Image.open(f"{tiles_dir}/cannon_top.png"),
"b": Image.open(f"{tiles_dir}/cannon_bottom.png"),
"F": Image.open(f"{tiles_dir}/flying_koopa.png"),
}
levels = [list(s) for s in level]
arr = np.array(levels)
return char_array_to_image(arr, chars2pngs), arr, level
return char_array_to_image(arr, chars2pngs, target_size), arr, level


def generate_timelapse(level_tensor, mario_lm, interval: int = 1):
images = []
full_size = math.ceil(level_tensor.shape[-1] / 14)
for i in range(1, level_tensor.shape[-1], interval):
img = convert_level_to_png(
level_tensor[:i], mario_lm.tokenizer, target_size=(14, full_size)
)[0]
images.append(img)
return images


def save_level(level: List[str], filename: str):
Expand Down
Loading

0 comments on commit 74d9ff8

Please sign in to comment.