From cb57b5f84544b63ec0c81fbe1bc6014fc4116807 Mon Sep 17 00:00:00 2001 From: AirportR Date: Fri, 1 Mar 2024 15:45:54 +0800 Subject: [PATCH] :sparkles: Automatically download the core. --- botmodule/init_bot.py | 80 ++++++++++++++++++++++++++++++++++------ utils/cleaner.py | 46 +++++++++++++++++++++-- utils/collector.py | 86 +++++++++++++++++++++++++++++++++++++++++-- 3 files changed, 194 insertions(+), 18 deletions(-) diff --git a/botmodule/init_bot.py b/botmodule/init_bot.py index 08eb887b..8dd6bca1 100644 --- a/botmodule/init_bot.py +++ b/botmodule/init_bot.py @@ -1,27 +1,19 @@ import asyncio import os import sys +from pathlib import Path from subprocess import check_output from loguru import logger -from utils.cleaner import config +from utils.cleaner import config, unzip_targz, unzip from utils import HOME_DIR +from utils.collector import get_latest_tag, Download, DownloadError admin = config.getAdmin() # 管理员 config.add_user(admin) # 管理员同时也是用户 config.reload() -# def check_permission(): -# if sys.platform != "win32": -# try: -# status = os.system(f"chmod +x {clash_path}") -# if status != 0: -# raise OSError(f"Failed to execute command: chmod +x {clash_path}") -# except OSError as o: -# print(o) - - def check_args(): import argparse help_text_socks5 = "设置socks5代理,bot代理使用的这个\n格式--> host:端口:用户名:密码\t用户名和密码可省略" @@ -76,6 +68,11 @@ def check_args(): class Init: + repo_owner = "AirportR" + repo_name = "FullTclash" + ftcore_owner = repo_owner + ftcore_name = "FullTCore" + @staticmethod def init_emoji(): emoji_source = config.config.get('emoji', {}).get('emoji-source', 'TwemojiLocalSource') @@ -128,6 +125,66 @@ def init_commit_string(): _latest_version_hash = "Unknown" return _latest_version_hash + @staticmethod + def init_proxy_client(): + """ + 自动下载代理客户端FullTCore + """ + if config.get_clash_path() is not None: + return + import platform + loop = asyncio.get_event_loop() + tag = loop.run_until_complete(get_latest_tag(Init.ftcore_owner, Init.ftcore_name)) + tag2 = tag[1:] if tag[0] == "v" else tag + arch = platform.machine().lower() + suffix = ".tar.gz" + if sys.platform.startswith('linux'): + pf = "linux" + elif sys.platform.startswith('darwin'): + pf = "darwin" + elif sys.platform.startswith('win32'): + pf = "windows" + suffix = ".zip" + else: + logger.info("无法找到FullTCore在当前平台的预编译文件,请自行下载。") + return + + # https://github.com/AirportR/FullTCore/releases/download/v1.3-meta/FullTCore_1.3-meta_windows_amd64.zip + base_url = f"https://github.com/{Init.ftcore_owner}/{Init.ftcore_name}" + + download_url = base_url + f"/releases/download/{tag}/FullTCore_{tag2}_{pf}_{arch}{suffix}" + savename = download_url.split("/")[-1] + logger.info(f"正在自动为您下载最新版本({tag})的FullTCore: {download_url}") + savepath = Path(HOME_DIR).joinpath("bin").absolute() + saved_file = savepath.joinpath(savename) + + try: + loop.run_until_complete(Download(download_url, savepath, savename).dowload(proxy=config.get_proxy())) + except DownloadError: + logger.info("无法找到FullTCore在当前平台的预编译文件,请自行下载。") + return + except (OSError, Exception) as e: + logger.info(str(e)) + return + + if suffix.endswith("zip"): + unzip_result = unzip(saved_file, savepath) + elif suffix.endswith("tar.gz"): + unzip_result = unzip_targz(saved_file, savepath) + else: + unzip_result = False + if unzip_result: + if pf == "windows": + corename = Init.ftcore_name + ".exe" + else: + corename = Init.ftcore_name + proxy_path = str(savepath.joinpath(corename).as_posix()) + clash_cfg = config.config.get('clash', {}) + clash_cfg = clash_cfg if isinstance(clash_cfg, dict) else {} + clash_cfg['path'] = proxy_path + config.yaml['clash'] = clash_cfg + config.reload() + def check_init(): if config.getClashBranch() == 'meta': @@ -135,6 +192,7 @@ def check_init(): Init.init_emoji() Init.init_dir() Init.init_permission() + Init.init_proxy_client() def check_version() -> str: diff --git a/utils/cleaner.py b/utils/cleaner.py index c0c6b722..8fed105a 100644 --- a/utils/cleaner.py +++ b/utils/cleaner.py @@ -1,8 +1,12 @@ import asyncio +import gzip import importlib import os +import tarfile +import shutil from copy import deepcopy +from pathlib import Path from typing import Union, List import socket @@ -979,9 +983,7 @@ def get_clash_path(self): try: return self.config['clash']['path'] except KeyError: - logger.warning("为减轻项目文件大小从3.6.5版本开始,不再默认提供代理客户端二进制文件,请自行前往以下网址获取: \n" - "https://github.com/AirportR/FullTCore/releases") - raise ValueError("找不到代理客户端二进制文件") + return None def get_sub(self, subname: str = None): """ @@ -1810,3 +1812,41 @@ def batch_ipcu(host: list): else: ipcu.append("N/A") return ipcu + + +def unzip_targz(input_file: str | Path, output_path: str | Path) -> List[str]: + """解压 tar.gz 文件, 返回解压后的文件名称列表""" + unzip_files = [] + try: + temp_path = input_file.rstrip(".gz") + path = Path(input_file) + path2 = Path(output_path) + with gzip.open(path, 'rb') as f_in, open(temp_path, 'wb') as f_out: + f_out.write(f_in.read()) + + with tarfile.open(temp_path) as tar: + for member in tar.getmembers(): + member_path = path2.joinpath(member.name) + logger.info("正在解压: " + str(member_path)) + f = tar.extractfile(member) + with open(member_path, 'wb') as ff: + ff.write(f.read()) + f.close() + unzip_files.append(str(member_path)) + archive_file = Path(temp_path).absolute() + archive_file.unlink() + + except Exception as e: + logger.error(str(e)) + finally: + return unzip_files + + +def unzip(input_file: str | Path, output_path: str | Path): + """解压zip文件""" + try: + output_path = Path(output_path) if isinstance(output_path, Path) else output_path + shutil.unpack_archive(str(input_file), output_path, format='zip') + return True + except shutil.Error: + return False diff --git a/utils/collector.py b/utils/collector.py index 0102059d..df9915a8 100644 --- a/utils/collector.py +++ b/utils/collector.py @@ -1,18 +1,22 @@ import asyncio import ssl import time +from datetime import datetime +from pathlib import Path -from typing import List +from typing import List, Union from urllib.parse import quote import aiohttp import async_timeout +from aiohttp import ClientSession from aiohttp.client_exceptions import ClientConnectorError, ContentTypeError from aiohttp_socks import ProxyConnector, ProxyConnectionError from loguru import logger from utils import cleaner +from utils.cleaner import config """ 这是整个项目最为核心的功能模块之一 —> 采集器。它负责从网络上采集想要的数据。到现在,已经设计了: @@ -27,7 +31,6 @@ 如果你想自己添加一个流媒体测试项,建议查看 ./resources/dos/新增流媒体测试项指南.md """ -config = cleaner.config addon = cleaner.addon media_items = config.get_media_item() proxies = config.get_proxy() # 代理 @@ -35,7 +38,7 @@ def reload_config(media: list = None): - global config, proxies, media_items + global proxies, media_items config.reload(issave=False) proxies = config.get_proxy() media_items = config.get_media_item() @@ -56,12 +59,74 @@ async def status(self, url, proxy=None): return response.status async def fetch(self, url, proxy=None): - with async_timeout.timeout(10): + async with async_timeout.timeout(10): async with aiohttp.ClientSession(headers=self._headers) as session: async with session.get(url, proxy=proxy) as response: return await response.content.read() +class DownloadError(aiohttp.ClientError): + """下载出错抛出的异常""" + + +class Download(BaseCollector): + def __init__(self, url: str = None, savepath: Union[str, Path] = None, savename: str = None): + _formatted_now = f'{datetime.now():%Y-%m-%dT%H-%M-%S}' + self.url = url + self.savepath = savepath + self.savename = savename if savename is not None else f"download-{_formatted_now}" + super().__init__() + + async def download_common(self, url: str = None, savepath: Union[str, Path] = None, **kwargs) -> bool: + """ + 通用下载函数 + """ + url = url or self.url + savepath = savepath or self.savepath or "." + savepath = str(savepath) + savepath = savepath if savepath.endswith("/") else savepath + "/" + savename = self.savename.lstrip("/") + write_path = savepath + savename + from utils.cleaner import geturl + url = geturl(url) + if not url: + raise DownloadError(f"这不是有效的URL: {url}") + try: + async with ClientSession(headers=self._headers) as session: + async with session.get(url, **kwargs) as resp: + if resp.status == 200: + content_leagth = resp.content_length if resp.content_length else 10 * 1024 * 1024 + length = 0 + with open(write_path, 'wb') as f: + while True: + chunk = await resp.content.read(1024) + length += len(chunk) + # 计算进度条长度 + percent = '=' * int(length * 100 / content_leagth) + spaces = ' ' * (100 - len(percent)) + print(f"\r[{percent}{spaces}] {length} B", end="") + if not chunk: + break + + f.write(chunk) + l2 = float(length) / 1024 / 1024 + l2 = round(l2, 2) + spath = str(Path(savepath).absolute()) + print(f"\r[{'=' * 100}] 共下载{length}B ({l2}MB)" + f"已保存到 {spath}") + elif resp.status == 404: + raise DownloadError(f"找不到资源: {resp.status}==>\t{url}") + return True + except (aiohttp.ClientError, OSError) as e: + raise DownloadError("Download failed") from e + + async def dowload(self, url: str = None, savepath: Union[str, Path] = None, **kwargs) -> bool: + """ + 执行下载操作 + """ + return await self.download_common(url, savepath, **kwargs) + + class IPCollector: """ GEOIP 测试采集类 @@ -503,5 +568,18 @@ async def start(self, host: str, port: int, proxy=None): return self.info +async def get_latest_tag(username, repo): + import re + url = f'https://github.com/{username}/{repo}/tags' + async with ClientSession() as session: + async with session.get(url, proxy=config.get_proxy(), timeout=10) as r: + text = await r.text() + tags = re.findall(r'/.*?/tag/(.*?)"', text) + if tags: + return tags[0] + else: + return None + + if __name__ == "__main__": pass