|
13 | 13 | import tempfile |
14 | 14 | import uuid |
15 | 15 | import warnings |
| 16 | +import zipfile |
16 | 17 | from collections import defaultdict |
17 | 18 | from http import HTTPStatus |
18 | 19 | from http.cookiejar import CookieJar |
@@ -3043,6 +3044,111 @@ def set_repo_visibility(self, |
3043 | 3044 |
|
3044 | 3045 | return resp |
3045 | 3046 |
|
| 3047 | + # ============= Collection API ============= |
| 3048 | + def get_collection(self, |
| 3049 | + collection_id: str, |
| 3050 | + repo_type: str = 'skill', |
| 3051 | + page_number: int = 1, |
| 3052 | + page_size: int = 50) -> dict: |
| 3053 | + """Get collection details and its elements. |
| 3054 | +
|
| 3055 | + Args: |
| 3056 | + collection_id (str): The collection ID (Fid). |
| 3057 | + repo_type (str): Element type filter, only 'skill' is supported currently. |
| 3058 | + page_number (int): Page number for pagination. |
| 3059 | + page_size (int): Page size for pagination. |
| 3060 | +
|
| 3061 | + Returns: |
| 3062 | + dict: Collection details including elements. |
| 3063 | +
|
| 3064 | + Raises: |
| 3065 | + ValueError: If repo_type is not 'skill'. |
| 3066 | + RequestError: If the API request fails. |
| 3067 | + """ |
| 3068 | + if repo_type != 'skill': |
| 3069 | + raise ValueError( |
| 3070 | + f'repo_type={repo_type} is not supported, ' |
| 3071 | + 'only "skill" is currently supported.') |
| 3072 | + cookies = self.get_cookies() |
| 3073 | + path = f'{self.endpoint}/api/v1/collections' |
| 3074 | + params = { |
| 3075 | + 'Fid': collection_id, |
| 3076 | + 'ElementType': repo_type, |
| 3077 | + 'PageNumber': page_number, |
| 3078 | + 'PageSize': page_size, |
| 3079 | + } |
| 3080 | + r = self.session.get(path, params=params, cookies=cookies, |
| 3081 | + headers=self.builder_headers(self.headers)) |
| 3082 | + raise_for_http_status(r) |
| 3083 | + d = r.json() |
| 3084 | + raise_on_error(d) |
| 3085 | + return d[API_RESPONSE_FIELD_DATA] |
| 3086 | + |
| 3087 | + def download_skill(self, skill_id: str, |
| 3088 | + local_dir: Optional[str] = None) -> str: |
| 3089 | + """Download a single skill archive and extract it. |
| 3090 | +
|
| 3091 | + Args: |
| 3092 | + skill_id (str): The skill identifier in format '<path>/<name>'. |
| 3093 | + local_dir (Optional[str]): Target directory for extraction. |
| 3094 | + Defaults to current directory. |
| 3095 | +
|
| 3096 | + Returns: |
| 3097 | + str: Path to the extracted skill directory. |
| 3098 | +
|
| 3099 | + Raises: |
| 3100 | + ValueError: If skill_id format is invalid. |
| 3101 | + RequestError: If the download request fails. |
| 3102 | + """ |
| 3103 | + element_path, element_name = RepoUtils.validate_repo_id(skill_id) |
| 3104 | + |
| 3105 | + cookies = self.get_cookies() |
| 3106 | + url = f'{self.endpoint}/api/v1/skills/{element_path}/{element_name}/archive/zip/master' |
| 3107 | + |
| 3108 | + if local_dir is None: |
| 3109 | + local_dir = os.getcwd() |
| 3110 | + os.makedirs(local_dir, exist_ok=True) |
| 3111 | + |
| 3112 | + # Build skill directory name: <element_path>__<element_name>__master |
| 3113 | + skill_dir_name = f'{element_path}__{element_name}__master' |
| 3114 | + skill_dir = os.path.join(local_dir, skill_dir_name) |
| 3115 | + |
| 3116 | + r = self.session.get(url, stream=True, cookies=cookies, |
| 3117 | + headers=self.builder_headers(self.headers)) |
| 3118 | + raise_for_http_status(r) |
| 3119 | + |
| 3120 | + # Save to temp zip file then extract |
| 3121 | + zip_path = os.path.join(local_dir, f'{element_name}.zip') |
| 3122 | + try: |
| 3123 | + with open(zip_path, 'wb') as f: |
| 3124 | + for chunk in r.iter_content(chunk_size=8192): |
| 3125 | + if chunk: |
| 3126 | + f.write(chunk) |
| 3127 | + |
| 3128 | + # Clean existing directory to avoid corrupted state |
| 3129 | + if os.path.exists(skill_dir): |
| 3130 | + shutil.rmtree(skill_dir) |
| 3131 | + os.makedirs(skill_dir, exist_ok=True) |
| 3132 | + with zipfile.ZipFile(zip_path, 'r') as zf: |
| 3133 | + zf.extractall(skill_dir) |
| 3134 | + |
| 3135 | + # Flatten if zip contains a single top-level directory |
| 3136 | + entries = os.listdir(skill_dir) |
| 3137 | + if len(entries) == 1: |
| 3138 | + nested_dir = os.path.join(skill_dir, entries[0]) |
| 3139 | + if os.path.isdir(nested_dir): |
| 3140 | + for item in os.listdir(nested_dir): |
| 3141 | + shutil.move( |
| 3142 | + os.path.join(nested_dir, item), |
| 3143 | + os.path.join(skill_dir, item)) |
| 3144 | + os.rmdir(nested_dir) |
| 3145 | + finally: |
| 3146 | + if os.path.exists(zip_path): |
| 3147 | + os.remove(zip_path) |
| 3148 | + |
| 3149 | + logger.info(f'Skill {element_path}/{element_name} downloaded to {skill_dir}') |
| 3150 | + return skill_dir |
| 3151 | + |
3046 | 3152 |
|
3047 | 3153 | class ModelScopeConfig: |
3048 | 3154 | path_credential = expanduser(MODELSCOPE_CREDENTIALS_PATH) |
|
0 commit comments