Skip to content

Commit

Permalink
refactor anaconda
Browse files Browse the repository at this point in the history
  • Loading branch information
sfc-gh-pczajka committed Apr 5, 2024
1 parent 0dfe66f commit 83a73f7
Showing 1 changed file with 25 additions and 21 deletions.
46 changes: 25 additions & 21 deletions src/snowflake/cli/plugins/snowpark/package/anaconda.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,27 @@ def __init__(self, packages: Dict[str, AnacondaPackageData]):
"""
self._packages = packages

@classmethod
def from_snowflake(cls):
try:
response = requests.get(AnacondaChannel.snowflake_channel_url)
response.raise_for_status()
packages = {}
for key, package in response.json()["packages"].items():
if not (version := package.get("version")):
continue
package_name = package.get("name", key)
standardized_name = Requirement.standardize_name(package_name)
packages[standardized_name] = AnacondaPackageData(
snowflake_name=package_name, versions={version}
)
return cls(packages)

except HTTPError as err:
raise ClickException(
f"Accessing Snowflake Anaconda channel failed. Reason {err}"
)

def is_package_available(
self, package: Requirement, skip_version_check: bool = False
) -> bool:
Expand Down Expand Up @@ -90,28 +111,11 @@ def package_versions(self, package: Requirement) -> List[str]:
"""Returns list of available versions of the package."""
if package.name not in self._packages:
return []
return list(sorted(self._packages[package.name].versions, reverse=True))

@classmethod
def from_snowflake(cls):
package_data = self._packages[package.name]
try:
response = requests.get(AnacondaChannel.snowflake_channel_url)
response.raise_for_status()
packages = {}
for key, package in response.json()["packages"].items():
if not (version := package.get("version")):
continue
package_name = package.get("name", key)
standardized_name = Requirement.standardize_name(package_name)
packages[standardized_name] = AnacondaPackageData(
snowflake_name=package_name, versions={version}
)
return cls(packages)

except HTTPError as err:
raise ClickException(
f"Accessing Snowflake Anaconda channel failed. Reason {err}"
)
return list(str(x) for x in sorted(package_data.iter_versions()))
except InvalidVersion:
return list(sorted(package_data.versions, reverse=True))

def filter_anaconda_packages(
self, packages: List[Requirement], skip_version_check: bool = False
Expand Down

0 comments on commit 83a73f7

Please sign in to comment.