Skip to content

Commit

Permalink
Added support for HQ-SAM (#161)
Browse files Browse the repository at this point in the history
* Revise download_checkpoint function

* Add hq_sam module
  • Loading branch information
giswqs authored Aug 13, 2023
1 parent eda8e55 commit ef1c4dd
Show file tree
Hide file tree
Showing 11 changed files with 940 additions and 45 deletions.
2 changes: 0 additions & 2 deletions docs/examples/arcgis.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@
"source": [
"sam = SamGeo(\n",
" model_type=\"vit_h\",\n",
" checkpoint=\"sam_vit_h_4b8939.pth\",\n",
" sam_kwargs=None,\n",
")"
]
Expand Down Expand Up @@ -418,7 +417,6 @@
"source": [
"sam = SamGeo(\n",
" model_type=\"vit_h\",\n",
" checkpoint=checkpoint,\n",
" sam_kwargs=sam_kwargs,\n",
")"
]
Expand Down
2 changes: 0 additions & 2 deletions docs/examples/automatic_mask_generator.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,6 @@
"source": [
"sam = SamGeo(\n",
" model_type=\"vit_h\",\n",
" checkpoint='sam_vit_h_4b8939.pth',\n",
" sam_kwargs=None,\n",
")"
]
Expand Down Expand Up @@ -287,7 +286,6 @@
"source": [
"sam = SamGeo(\n",
" model_type=\"vit_h\",\n",
" checkpoint=checkpoint,\n",
" sam_kwargs=sam_kwargs,\n",
")"
]
Expand Down
1 change: 0 additions & 1 deletion docs/examples/box_prompts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,6 @@
"source": [
"sam = SamGeo(\n",
" model_type=\"vit_h\",\n",
" checkpoint=\"sam_vit_h_4b8939.pth\",\n",
" automatic=False,\n",
" sam_kwargs=None,\n",
")"
Expand Down
1 change: 0 additions & 1 deletion docs/examples/input_prompts.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,6 @@
"source": [
"sam = SamGeo(\n",
" model_type=\"vit_h\",\n",
" checkpoint=\"sam_vit_h_4b8939.pth\",\n",
" automatic=False,\n",
" sam_kwargs=None,\n",
")"
Expand Down
3 changes: 3 additions & 0 deletions docs/hq_sam.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# hq_sam module

::: samgeo.hq_sam
1 change: 1 addition & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,4 +60,5 @@ nav:
- API Reference:
- common module: common.md
- samgeo module: samgeo.md
- hq_sam module: hq_sam.md
- text_sam module: text_sam.md
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ gdown
xyzservices
pyproj
leafmap
localtileserver
localtileserver
timm
1 change: 1 addition & 0 deletions requirements_dev.txt
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
rio-cogeo
segment-anything-hq
65 changes: 64 additions & 1 deletion samgeo/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,70 @@ def download_file(
return os.path.abspath(output)


def download_checkpoint(url=None, output=None, overwrite=False, **kwargs):
def download_checkpoint(model_type="vit_h", checkpoint_dir=None, hq=False):
"""Download the SAM model checkpoint.
Args:
model_type (str, optional): The model type. Can be one of ['vit_h', 'vit_l', 'vit_b'].
Defaults to 'vit_h'. See https://bit.ly/3VrpxUh for more details.
checkpoint_dir (str, optional): The checkpoint_dir directory. Defaults to None, "~/.cache/torch/hub/checkpoints".
hq (bool, optional): Whether to use HQ-SAM model (https://github.com/SysCV/sam-hq). Defaults to False.
"""

if not hq:
model_types = {
"vit_h": {
"name": "sam_vit_h_4b8939.pth",
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth",
},
"vit_l": {
"name": "sam_vit_l_0b3195.pth",
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth",
},
"vit_b": {
"name": "sam_vit_b_01ec64.pth",
"url": "https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth",
},
}
else:
model_types = {
"vit_h": {
"name": "sam_hq_vit_h.pth",
"url": "https://drive.google.com/file/d/1qobFYrI4eyIANfBSmYcGuWRaSIXfMOQ8/view?usp=sharing",
},
"vit_l": {
"name": "sam_hq_vit_l.pth",
"url": "https://drive.google.com/file/d/1Uk17tDKX1YAKas5knI4y9ZJCo0lRVL0G/view?usp=sharing",
},
"vit_b": {
"name": "sam_hq_vit_b.pth",
"url": "https://drive.google.com/file/d/11yExZLOve38kRZPfRx_MRxfIAKmfMY47/view?usp=sharing",
},
"vit_tiny": {
"name": "sam_hq_vit_tiny.pth",
"url": "https://huggingface.co/lkeab/hq-sam/resolve/main/sam_hq_vit_tiny.pth",
},
}

if model_type not in model_types:
raise ValueError(
f"Invalid model_type: {model_type}. It must be one of {', '.join(model_types)}"
)

if checkpoint_dir is None:
checkpoint_dir = os.environ.get(
"TORCH_HOME", os.path.expanduser("~/.cache/torch/hub/checkpoints")
)

checkpoint = os.path.join(checkpoint_dir, model_types[model_type]["name"])
if not os.path.exists(checkpoint):
print(f"Model checkpoint for {model_type} not found.")
url = model_types[model_type]["url"]
download_file(url, checkpoint)
return checkpoint


def download_checkpoint_legacy(url=None, output=None, overwrite=False, **kwargs):
"""Download a checkpoint from URL. It can be one of the following: sam_vit_h_4b8939.pth, sam_vit_l_0b3195.pth, sam_vit_b_01ec64.pth.
Args:
Expand Down
Loading

0 comments on commit ef1c4dd

Please sign in to comment.