Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Comparison to "Threads" #430

Open
hyanwong opened this issue Aug 25, 2024 · 1 comment
Open

Comparison to "Threads" #430

hyanwong opened this issue Aug 25, 2024 · 1 comment

Comments

@hyanwong
Copy link
Member

hyanwong commented Aug 25, 2024

I thought it would be interesting to compare how tsinfer+tsdate do against the new "Threads" program (https://pypi.org/project/threads-arg/0.1.0/).

As a test, I have a stdpopsim model with 2 chimp species (3 populations) and some selective sweeps in 2 of the populations:

Screenshot 2024-08-25 at 17 20 40

Here are the edge plots from the original vs tsinfer+tsdate and Threads on 5Mb of genome: the 2 recent selective sweeps (in western at 1/3rd along the genome, and bonobo at 2/3rds) are obvious in the plot of the true edges. Threads seems a little better at picking up the recent sweeps, but maybe picks up less of the demographic banding at the top (although to be fair, I told it there was a fixed population size of 100,000 haploid genomes). For this scale of data (120 genomes), Threads is much faster than tsinfer (and its relative speed advantage grows as the genome length increases, but I suspect we would see a relative slowdown if the sample size increased to very large numbers).

Screenshot 2024-08-25 at 17 24 12
Click for code
import os
import subprocess
import tempfile

import stdpopsim
import msprime
import tskit
import numpy as np
import pandas as pd
import tszip
import matplotlib.pyplot as plt
import matplotlib as mpl
import tsinfer
import tsdate

def sweep_and_demography_Pan(
    sequence_length,
    sample_sizes,  # could be a single number, or a dict of pop_name: size
    sweep_params=None,
    chrom="chr3",
    random_seed=123,
):
    """
    Make an example of a 3-population reasonably complex demography with a few
    more recent selective sweeps.

    The simulation runs in two phases. First, each sampled population is
    simulated on its own (constant size, optionally with a genic selective
    sweep) from the present back to ``G`` generations ago. The per-population
    tree sequences are then unioned and recapitated under the stdpopsim
    "BonoboGhost_4K19" demographic model, and mutations are added.

    :param sequence_length: Length of genome to simulate; also used as the
        right-hand end when slicing the contig's recombination map.
    :param sample_sizes: Either a single number, applied to every population
        whose ``sampling_time`` metadata is not None, or a dict mapping
        population name to the sample count handed to ``msprime.sim_ancestry``.
    :param sweep_params: Optional dict mapping population name to extra keyword
        arguments for ``msprime.SweepGenicSelection`` (the code reads the
        ``"position"`` and ``"s"`` keys). Populations not listed are simulated
        neutrally. ``None`` means no sweeps.
    :param chrom: Name of the stdpopsim contig supplying the recombination map.
    :param random_seed: Seed used for every ancestry and mutation simulation.
    :return: A tuple ``(mutated_ts, model, ratemap)``: the simulated tree
        sequence with mutations, the stdpopsim demographic model, and the
        sliced recombination map.
    """
    if sweep_params is None:
        sweep_params = {}

    species = stdpopsim.get_species("PanTro")
    model = species.get_demographic_model("BonoboGhost_4K19")
    msprime_demography = model.model
    contig = species.get_contig(chrom, mutation_rate=model.mutation_rate)
    ratemap = contig.recombination_map.slice(right=sequence_length, trim=True)

    # Expand a single number into a per-population dict. Only the int()
    # conversion is guarded, so a TypeError raised elsewhere is not masked.
    try:
        size_per_pop = int(sample_sizes)
    except TypeError:
        pass  # Not a single number: assume it is already a dict of pop_name: size
    else:
        sample_sizes = {
            name: size_per_pop
            for name, info in msprime_demography.items()
            # Don't sample from the ghost population
            if info.extra_metadata['sampling_time'] is not None
        }

    # A time ago in generations: we assume populations from time 0..G are
    # isolated and of constant size
    G = 4000

    # Make independent populations, some with selective sweeps
    independent_pop_ts = []
    for name, pop in msprime_demography.items():
        if name not in sample_sizes:
            continue
        Ne = pop.initial_size
        demog = msprime.Demography()
        demog.add_population(name=name, initial_size=Ne)
        if name in sweep_params:
            sweep = sweep_params[name]
            # Sweep from a single copy (1/2Ne) up to one copy short of fixation
            p = 1 / (2 * Ne)
            freqs = {"start_frequency": p, "end_frequency": 1 - p, "dt": 1 / (40 * Ne)}
            sweep_model = msprime.SweepGenicSelection(**freqs, **sweep)
            # Once the sweep phase completes, continue under the standard coalescent
            models = (sweep_model, msprime.StandardCoalescent())
            print(f"Adding {name} population to demographic model, sweep at {int(sweep['position'])}bp, selection coefficient s={sweep['s']}")
        else:
            models = msprime.StandardCoalescent()
            print(f"Adding {name} population to demographic model, neutral")
        independent_pop_ts.append(msprime.sim_ancestry(
            sample_sizes[name],
            model=models,
            demography=demog,
            recombination_rate=ratemap,
            sequence_length=sequence_length,
            end_time=G,
            random_seed=random_seed,
        ))

    # Merge the per-population simulations. The all-NULL node mapping declares
    # that no nodes are shared between the tree sequences being unioned.
    combined_ts = independent_pop_ts[0]
    for ts in independent_pop_ts[1:]:
        combined_ts = combined_ts.union(ts, node_mapping=np.full(ts.num_nodes, tskit.NULL))

    # Now recapitate: initial_state uses the population names in the combined_ts
    # to figure out which are which. The rate map is passed explicitly so that
    # recombination continues during this phase; if it is omitted, msprime
    # simulates the remaining ancestry with no recombination.
    ts = msprime.sim_ancestry(
        initial_state=combined_ts,
        demography=msprime_demography,
        recombination_rate=ratemap,
        random_seed=random_seed,
    ).simplify()
    return msprime.sim_mutations(ts, rate=model.mutation_rate, random_seed=random_seed), model, ratemap


# Length of genome to simulate: 5 Mb
L = 5e6

# One selective sweep per listed population: "position" is where along the
# genome the sweep occurs, "s" is the selection coefficient.
sweep_params = {
    "western": {
        "position": L // 3,
        "s": 0.1,
    },
    "bonobo": {
        "position": (2 * L) // 3,
        "s": 0.05,
    },
}

# Same sample size (20) for every sampled population
ts, model, ratemap = sweep_and_demography_Pan(L, 20, sweep_params=sweep_params)

print(f"Simulated {ts.num_sites} sites")

def run_threads(input_ts, ratemap, demography, plink2_path="./plink2"):
    """
    Run Threads on the variation data in ``input_ts`` and return the result.

    The data is exported to VCF, converted to plink2 format by the external
    ``plink2`` program, then passed to the external ``threads infer`` and
    ``threads convert`` commands. All intermediate files live in a temporary
    directory that is removed when the function returns.

    :param input_ts: Tree sequence providing the sites and genotypes. Sites
        that do not have exactly two alleles are dropped before export.
    :param ratemap: Recombination map; its cumulative mass at each site
        position (times 100) is written as the cM column of the genetic map.
    :param demography: Currently ignored: the demography is hard-coded as a
        hack to a constant size of 100,000 haploid genomes.
    :param plink2_path: Path to the plink2 executable.
    :return: The tree sequence loaded from the ``.tsz`` file written by
        ``threads convert``.
    :raises subprocess.CalledProcessError: If plink2 or either Threads command
        exits with a non-zero status.
    """
    # Keep only biallelic sites (this drops multiallelic sites and also any
    # site with fewer than two alleles)
    ts = input_ts.delete_sites([s.id for s in input_ts.sites() if len(s.alleles) != 2])

    # Make .pgen & .pvar files
    n_dip_indv = ts.num_samples // 2
    indv_names = [f"tsk_{i}indv" for i in range(n_dip_indv)]
    with tempfile.TemporaryDirectory() as tmpdirname:
        tmp_fn_prefix = os.path.join(tmpdirname, "tmp")
        with open(f"{tmp_fn_prefix}.vcf", "wt") as vcf_file:
            ts.write_vcf(vcf_file, individual_names=indv_names)

        # check=True: fail here with a clear error rather than later, when a
        # missing output file would be the only symptom
        subprocess.run(
            [plink2_path, "--vcf", f"{tmp_fn_prefix}.vcf", "--out", tmp_fn_prefix],
            check=True,
        )

        # Make a map with a position for each SNP
        df = pd.DataFrame({
            "chr": np.repeat("Chr1", ts.num_sites),
            "SNP": np.arange(ts.num_sites),
            "cM": ratemap.get_cumulative_mass(ts.sites_position) * 100,
            "bp": ts.sites_position.astype(int),
        })
        df.to_csv(f"{tmp_fn_prefix}.map.gz", sep="\t", index=False, header=False)

        # Hack a demography file: two tab-separated columns (generations ago,
        # haploid Ne) describing a constant population size
        times = np.array([0, 1e6])
        diploid_size = np.array([50_000, 50_000])

        df = pd.DataFrame({"gens_ago": times.astype(int), "haploid_Ne": diploid_size.astype(int) * 2})
        df.to_csv(f"{tmp_fn_prefix}.demo", sep="\t", index=False, header=False)

        subprocess.run(
            [
                "threads",
                "infer",
                "--pgen", f"{tmp_fn_prefix}.pgen",
                "--map_gz", f"{tmp_fn_prefix}.map.gz",
                "--demography", f"{tmp_fn_prefix}.demo",
                "--out", f"{tmp_fn_prefix}.threads",
            ],
            check=True,
        )
        subprocess.run(
            [
                "threads",
                "convert",
                "--threads", f"{tmp_fn_prefix}.threads",
                "--tsz", f"{tmp_fn_prefix}.tsz",
            ],
            check=True,
        )
        # Load the result before the temporary directory is deleted
        return tszip.decompress(f"{tmp_fn_prefix}.tsz")

def _infer_then_date(sample_data):
    """
    Run tsinfer on ``sample_data`` and then date the result with tsdate.

    Returns a tuple ``(inferred_ts, dated_ts)``.
    """
    inferred_ts = tsinfer.infer(
        sample_data,
        num_threads=6,
        progress_monitor=True,
    )
    dated_ts = tsdate.date(
        tsdate.preprocess_ts(inferred_ts),
        mutation_rate=model.mutation_rate,
        rescaling_intervals=100,
    )
    return inferred_ts, dated_ts


# Threads inference (the demography argument is currently unused, so pass None)
threads_ts = run_threads(ts, ratemap, None)


# First round: infer from the simulated data, then date
tsinfer_ts, tsinfer_tsdate_ts = _infer_then_date(
    tsinfer.SampleData.from_tree_sequence(ts)
)
# reinfer: second round, building the sample data from the dated tree
# sequence with use_sites_time=True
tsinfer2_ts, tsinfer2_tsdate_ts = _infer_then_date(
    tsinfer.SampleData.from_tree_sequence(tsinfer_tsdate_ts, use_sites_time=True)
)


def edge_plot(plot_ts, ax):
    """
    Draw each edge of ``plot_ts`` on ``ax`` as a horizontal line.

    Every edge is drawn from its left to its right coordinate (x axis) at the
    time of its parent node (y axis, log scale). Lines are semi-transparent.
    """
    parent_time = plot_ts.nodes_time[plot_ts.edges_parent]
    # One (x, y) point per edge for each end of the line ...
    start_points = np.column_stack((plot_ts.edges_left, parent_time))
    end_points = np.column_stack((plot_ts.edges_right, parent_time))
    # ... paired up into an array of shape (num_edges, 2, 2)
    segments = np.stack((start_points, end_points), axis=1)

    lines = mpl.collections.LineCollection(segments, alpha=0.2)
    ax.add_collection(lines)
    # add_collection does not rescale the axes, so fit the view to the data
    ax.autoscale()
    ax.margins(0)
    ax.set_yscale("log")


# Three stacked panels sharing the genome-position axis: truth, tsinfer+tsdate, Threads
fig, (ax_orig, ax_tsinfer, ax_threads) = plt.subplots(3, 1, figsize=(15, 15), sharex=True)

edge_plot(ts, ax_orig)
ax_orig.set_title(f"True edges & times ({ts.num_edges} edges)")

edge_plot(tsinfer2_tsdate_ts, ax_tsinfer)
ax_tsinfer.set_title(f"Tsinfer + tsdate ({tsinfer2_tsdate_ts.num_edges} edges)")

# Simplify the Threads output once and reuse it for both the plot and the title
threads_simplified_ts = threads_ts.simplify()
edge_plot(threads_simplified_ts, ax_threads)
ax_threads.set_title(f"Threads ({threads_simplified_ts.num_edges} edges)")
ax_threads.set_xlabel("Genome position")
@hyanwong
Copy link
Member Author

Incidentally, here's how tsinfer+tsdate does if we feed in the true times

Screenshot 2024-08-25 at 17 35 37

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant