diff --git a/python/python/embed_anything/__init__.py b/python/python/embed_anything/__init__.py index 9770cb5..4ce0700 100644 --- a/python/python/embed_anything/__init__.py +++ b/python/python/embed_anything/__init__.py @@ -120,6 +120,7 @@ """ from ._embed_anything import * from .vectordb import * +import platform import os import onnxruntime import glob @@ -128,11 +129,19 @@ if path is None: print("onnxruntime is not installed. Install it using `pip install onnxruntime`") - else: - dylib_path = glob.glob(os.path.join(path, "libonnxruntime.so*")) - os.environ["ORT_DYLIB_PATH"] = dylib_path[0] + if platform.system() == "Windows": + # For Windows, look for DLL files + dylib_path = glob.glob(os.path.join(path, "onnxruntime.dll")) + else: + # For Linux, look for shared object files + dylib_path = glob.glob(os.path.join(path, "libonnxruntime.so*")) + + if dylib_path: + os.environ["ORT_DYLIB_PATH"] = dylib_path[0] + else: + print("onnxruntime dynamic library not found.") __doc__ = _embed_anything.__doc__ if hasattr(_embed_anything, "__all__"): - __all__ = _embed_anything.__all__ + __all__ = _embed_anything.__all__ \ No newline at end of file