Skip to content

Commit

Permalink
Change checkpoints download URLs (#259)
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs authored Mar 30, 2024
1 parent d0e7162 commit 116e515
Show file tree
Hide file tree
Showing 4 changed files with 148 additions and 10 deletions.
4 changes: 1 addition & 3 deletions docs/examples/maxar_open_data.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,7 @@
},
"outputs": [],
"source": [
"url = (\n",
" \"https://drive.google.com/file/d/1jIIC5hvSPeJEC0fbDhtxVWk2XV9AxsQD/view?usp=sharing\"\n",
")"
"url = \"https://github.com/opengeos/datasets/releases/download/raster/Derna_sample.tif\""
]
},
{
Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ leafmap
localtileserver
matplotlib
opencv-python
patool
pycocotools
pyproj
rasterio
Expand Down
149 changes: 144 additions & 5 deletions samgeo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,19 +209,22 @@ def download_checkpoint(model_type="vit_h", checkpoint_dir=None, hq=False):
model_types = {
"vit_h": {
"name": "sam_hq_vit_h.pth",
"url": "https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing",
"url": [
"https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_h.zip",
"https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_h.z01",
],
},
"vit_l": {
"name": "sam_hq_vit_l.pth",
"url": "https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing",
"url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_l.pth",
},
"vit_b": {
"name": "sam_hq_vit_b.pth",
"url": "https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing",
"url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_b.pth",
},
"vit_tiny": {
"name": "sam_hq_vit_tiny.pth",
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth",
"url": "https://github.com/opengeos/datasets/releases/download/models/sam_hq_vit_tiny.pth",
},
}

Expand All @@ -239,7 +242,10 @@ def download_checkpoint(model_type="vit_h", checkpoint_dir=None, hq=False):
if not os.path.exists(checkpoint):
print(f"Model checkpoint for {model_type} not found.")
url = model_types[model_type]["url"]
download_file(url, checkpoint)
if isinstance(url, str):
download_file(url, checkpoint)
elif isinstance(url, list):
download_files(url, checkpoint_dir, multi_part=True)
return checkpoint


Expand Down Expand Up @@ -2987,3 +2993,136 @@ def merge_rasters(
dstNodata=output_nodata,
options=output_options,
)


def extract_archive(archive, outdir=None, **kwargs):
"""
Extracts a multipart archive.
This function uses the patoolib library to extract a multipart archive.
If the patoolib library is not installed, it attempts to install it.
If the archive does not end with ".zip", it appends ".zip" to the archive name.
If the extraction fails (for example, if the files already exist), it skips the extraction.
Args:
archive (str): The path to the archive file.
outdir (str): The directory where the archive should be extracted.
**kwargs: Arbitrary keyword arguments for the patoolib.extract_archive function.
Returns:
None
Raises:
Exception: An exception is raised if the extraction fails for reasons other than the files already existing.
Example:
files = ["sam_hq_vit_tiny.zip", "sam_hq_vit_tiny.z01", "sam_hq_vit_tiny.z02", "sam_hq_vit_tiny.z03"]
base_url = "https://github.com/opengeos/datasets/releases/download/models/"
urls = [base_url + f for f in files]
leafmap.download_files(urls, out_dir="models", multi_part=True)
"""
try:
import patoolib
except ImportError:
install_package("patool")
import patoolib

if not archive.endswith(".zip"):
archive = archive + ".zip"

if outdir is None:
outdir = os.path.dirname(archive)

try:
patoolib.extract_archive(archive, outdir=outdir, **kwargs)
except Exception as e:
print("The unzipped files might already exist. Skipping extraction.")
return


def download_files(
urls,
out_dir=None,
filenames=None,
quiet=False,
proxy=None,
speed=None,
use_cookies=True,
verify=True,
id=None,
fuzzy=False,
resume=False,
unzip=True,
overwrite=False,
subfolder=False,
multi_part=False,
):
"""Download files from URLs, including Google Drive shared URL.
Args:
urls (list): The list of urls to download. Google Drive URL is also supported.
out_dir (str, optional): The output directory. Defaults to None.
filenames (list, optional): Output filename. Default is basename of URL.
quiet (bool, optional): Suppress terminal output. Default is False.
proxy (str, optional): Proxy. Defaults to None.
speed (float, optional): Download byte size per second (e.g., 256KB/s = 256 * 1024). Defaults to None.
use_cookies (bool, optional): Flag to use cookies. Defaults to True.
verify (bool | str, optional): Either a bool, in which case it controls whether the server's TLS certificate is verified, or a string, in which case it must be a path to a CA bundle to use. Default is True.. Defaults to True.
id (str, optional): Google Drive's file ID. Defaults to None.
fuzzy (bool, optional): Fuzzy extraction of Google Drive's file Id. Defaults to False.
resume (bool, optional): Resume the download from existing tmp file if possible. Defaults to False.
unzip (bool, optional): Unzip the file. Defaults to True.
overwrite (bool, optional): Overwrite the file if it already exists. Defaults to False.
subfolder (bool, optional): Create a subfolder with the same name as the file. Defaults to False.
multi_part (bool, optional): If the file is a multi-part file. Defaults to False.
Examples:
files = ["sam_hq_vit_tiny.zip", "sam_hq_vit_tiny.z01", "sam_hq_vit_tiny.z02", "sam_hq_vit_tiny.z03"]
base_url = "https://github.com/opengeos/datasets/releases/download/models/"
urls = [base_url + f for f in files]
leafmap.download_files(urls, out_dir="models", multi_part=True)
"""

if out_dir is None:
out_dir = os.getcwd()

if filenames is None:
filenames = [None] * len(urls)

filepaths = []
for url, output in zip(urls, filenames):
if output is None:
filename = os.path.join(out_dir, os.path.basename(url))
else:
filename = os.path.join(out_dir, output)

filepaths.append(filename)
if multi_part:
unzip = False

download_file(
url,
filename,
quiet,
proxy,
speed,
use_cookies,
verify,
id,
fuzzy,
resume,
unzip,
overwrite,
subfolder,
)

if multi_part:
archive = os.path.splitext(filename)[0] + ".zip"
out_dir = os.path.dirname(filename)
extract_archive(archive, out_dir)

for file in filepaths:
os.remove(file)
4 changes: 2 additions & 2 deletions samgeo/fast_sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@ def __init__(self, model="FastSAM-x.pt", **kwargs):
)

models = {
"FastSAM-x.pt": "https://drive.google.com/file/d/1m1sjY4ihXBU1fZXdQ-Xdj-mDltW-2Rqv/view?usp=sharing",
"FastSAM-s.pt": "https://drive.google.com/file/d/10XmSj6mmpmRb8NhXbtiuO9cTTBwR_9SV/view?usp=sharing",
"FastSAM-x.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-x.pt",
"FastSAM-s.pt": "https://github.com/opengeos/datasets/releases/download/models/FastSAM-s.pt",
}

if model not in models:
Expand Down

0 comments on commit 116e515

Please sign in to comment.