Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(closes #35) Changes for jax-metal #52

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion entropix/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def precompute_freqs_cis(dim: int, end: int, theta: float = 500000.0, use_scaled
freqs = apply_scaling(freqs)
t = jnp.arange(end, dtype=dtype)
freqs = jnp.outer(t, freqs)
return jnp.exp(1j * freqs)
return jnp.stack([jnp.cos(freqs), jnp.sin(freqs)], axis=-1)


def build_attn_mask(seqlen: int, start_pos: int) -> jax.Array:
Expand Down
14 changes: 8 additions & 6 deletions entropix/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,14 @@ def rms_norm(x: jax.Array, w: jax.Array, eps: float = 1e-6) -> jax.Array:
def apply_rotary_emb(xq: jax.Array, xk: jax.Array, freqs_cis: jax.Array, dtype: jnp.dtype = jnp.float32) -> Tuple[jax.Array, jax.Array]:
reshape_xq = xq.astype(jnp.float32).reshape(*xq.shape[:-1], -1, 2)
reshape_xk = xk.astype(jnp.float32).reshape(*xk.shape[:-1], -1, 2)
xq_ = jax.lax.complex(reshape_xq[..., 0], reshape_xq[..., 1])
xk_ = jax.lax.complex(reshape_xk[..., 0], reshape_xk[..., 1])
xq_out = xq_ * freqs_cis[None, :, None, :]
xk_out = xk_ * freqs_cis[None, :, None, :]
xq_out = jnp.stack((jnp.real(xq_out), jnp.imag(xq_out)), axis=-1).reshape(*xq_out.shape[:-1], -1)
xk_out = jnp.stack((jnp.real(xk_out), jnp.imag(xk_out)), axis=-1).reshape(*xk_out.shape[:-1], -1)
freqs_c = freqs_cis[None, :, None, :, 0]
freqs_s = freqs_cis[None, :, None, :, 1]
xq_re = reshape_xq[..., 0] * freqs_c - reshape_xq[..., 1] * freqs_s
xq_im = reshape_xq[..., 0] * freqs_s + reshape_xq[..., 1] * freqs_c
xq_out = jnp.stack((xq_re, xq_im), axis=-1).reshape(*xq_re.shape[:-1], -1)
xk_re = reshape_xk[..., 0] * freqs_c - reshape_xk[..., 1] * freqs_s
xk_im = reshape_xk[..., 0] * freqs_s + reshape_xk[..., 1] * freqs_c
xk_out = jnp.stack((xk_re, xk_im), axis=-1).reshape(*xk_re.shape[:-1], -1)
return xq_out.astype(dtype), xk_out.astype(dtype)

#@partial(jax.jit, static_argnames=("model_params", "cur_pos", "layer_idx"))
Expand Down
17 changes: 16 additions & 1 deletion entropix/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,21 @@ def multinomial_sample_one(probs_sort: jax.Array, key) -> jax.Array:
q = jax.random.exponential(key=key, shape=probs_sort.shape)
return jnp.argmax(probs_sort / q, axis=-1, keepdims=True).astype(jnp.int32)

# TODO: this should depend on whether device is metal.
# At time of writing, jax-metal did not support jax.lax.top_k.
use_jax_top_k = False

def _top_k(x, k):
if use_jax_top_k:
return jax.lax.top_k(x, k=k)
# jax.lax.top_k fails when using jax-metal, so reimplement it.
# You can't backprop through this version but it doesn't matter
# in the sampler.
sorted_indices = jnp.argsort(x, axis=-1)
indices = jnp.flip(sorted_indices[..., -k:], axis=-1)
values = jnp.take_along_axis(x, indices, axis=-1)
return values, indices

def _sample( logits: jax.Array, *, temperature: float | jax.Array, top_p: float | jax.Array, top_k: int | jax.Array, min_p: float | jax.Array,
key=jax.random.PRNGKey(1337),) -> jax.Array:
bsz = logits.shape[0]
Expand All @@ -33,7 +48,7 @@ def _sample( logits: jax.Array, *, temperature: float | jax.Array, top_p: float
logit = jnp.where(indices_to_remove, jnp.full_like(logit, float('-inf')), logit)

# Apply top-k sampling
top_k_probs, top_k_indices = jax.lax.top_k(probs, k=top_k)
top_k_probs, top_k_indices = _top_k(probs, k=top_k)
probs_sort = jnp.flip(top_k_probs, axis=-1)
probs_idx = jnp.flip(top_k_indices, axis=-1)
probs_sum = jnp.cumsum(probs_sort, axis=-1)
Expand Down
10 changes: 7 additions & 3 deletions entropix/weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ def load_weights(ckpt_dir: Path, n_layers: int = 16):
w = {}
layer_weights = []
try:
device = jax.devices("gpu")[0]
device = jax.devices("METAL")[0]
except RuntimeError:
print("GPU not found. Using CPU instead.")
device = jax.devices("cpu")[0]
print("Metal not found, trying GPU.")
try:
device = jax.devices("gpu")[0]
except RuntimeError:
print("GPU not found. Using CPU instead.")
device = jax.devices("cpu")[0]
for file in ckpt_dir.glob("*.npy"):
name = '.'.join(str(file).split('/')[-1].split('.')[:-1])
weight = jnp.load(file=file, mmap_mode='r', allow_pickle=True)
Expand Down