Skip to content

Commit a49e1bc

Browse files
Feat/collections (#1656)
1 parent ce6a902 commit a49e1bc

File tree

8 files changed

+335
-4
lines changed

8 files changed

+335
-4
lines changed

modelscope/cli/cli.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from modelscope.cli.plugins import PluginsCMD
1414
from modelscope.cli.scancache import ScanCacheCMD
1515
from modelscope.cli.server import ServerCMD
16+
from modelscope.cli.skills import SkillsCMD
1617
from modelscope.cli.upload import UploadCMD
1718
from modelscope.hub.constants import MODELSCOPE_ASCII
1819
from modelscope.utils.logger import get_logger
@@ -36,6 +37,7 @@ def run_cmd():
3637

3738
CreateCMD.define_args(subparsers)
3839
DownloadCMD.define_args(subparsers)
40+
SkillsCMD.define_args(subparsers)
3941
UploadCMD.define_args(subparsers)
4042
ClearCacheCMD.define_args(subparsers)
4143
PluginsCMD.define_args(subparsers)

modelscope/cli/download.py

Lines changed: 62 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
# Copyright (c) Alibaba, Inc. and its affiliates.
2-
import os
2+
import logging
33
from argparse import ArgumentParser
44

55
from modelscope.cli.base import CLICommand
6+
from modelscope.cli.utils import concurrent_download
67
from modelscope.hub.api import HubApi
7-
from modelscope.hub.constants import DEFAULT_MAX_WORKERS
8+
from modelscope.hub.constants import DEFAULT_MAX_WORKERS, DEFAULT_SKILLS_DIR
89
from modelscope.hub.file_download import (dataset_file_download,
910
model_file_download)
1011
from modelscope.hub.snapshot_download import (dataset_snapshot_download,
1112
snapshot_download)
1213
from modelscope.hub.utils.utils import convert_patterns
1314
from modelscope.utils.constant import DEFAULT_DATASET_REVISION
15+
from modelscope.utils.logger import get_logger
16+
17+
logger = get_logger(log_level=logging.WARNING)
1418

1519

1620
def subparser_func(args):
@@ -41,6 +45,11 @@ def define_args(parsers: ArgumentParser):
4145
type=str,
4246
help='The id of the dataset to be downloaded. For download, '
4347
'the id of either a model or dataset must be provided.')
48+
group.add_argument(
49+
'--collection',
50+
type=str,
51+
default=None,
52+
help='The ID of the collection to download (skills only)')
4453
parser.add_argument(
4554
'repo_id',
4655
type=str,
@@ -122,8 +131,8 @@ def execute(self):
122131
else:
123132
raise Exception('Not support repo-type: %s'
124133
% self.args.repo_type)
125-
if not self.args.model and not self.args.dataset:
126-
raise Exception('Model or dataset must be set.')
134+
if not self.args.model and not self.args.dataset and not self.args.collection:
135+
raise Exception('Model, dataset, or collection must be set.')
127136
cookies = None
128137
if self.args.token is not None:
129138
api = HubApi()
@@ -191,5 +200,54 @@ def execute(self):
191200
print(
192201
f'\nSuccessfully Downloaded from dataset {self.args.dataset}.\n'
193202
)
203+
elif self.args.collection:
204+
api = HubApi(token=self.args.token)
205+
local_dir = self.args.local_dir or DEFAULT_SKILLS_DIR
206+
data = api.get_collection(self.args.collection, repo_type='skill')
207+
elements = data.get('CollectionElements',
208+
{}).get('CollectionElementVoList', [])
209+
210+
logger.info(
211+
f'Collection {self.args.collection} has {len(elements)} elements.'
212+
)
213+
214+
if not elements:
215+
print(f'No skill elements found in collection: '
216+
f'{self.args.collection}')
217+
return
218+
219+
# Validate elements have required fields
220+
valid_elements = []
221+
for elem in elements:
222+
if not elem.get('ElementPath') or not elem.get('ElementName'):
223+
logger.warning('Skipping malformed collection element: %s',
224+
elem)
225+
continue
226+
valid_elements.append(elem)
227+
228+
if not valid_elements:
229+
print(f'No valid skill elements found in collection: '
230+
f'{self.args.collection}')
231+
return
232+
233+
print(f'Found {len(valid_elements)} skill(s) in collection, '
234+
f'downloading...')
235+
236+
def _download_one_skill(element):
237+
element_path = element['ElementPath']
238+
element_name = element['ElementName']
239+
skill_id = f'{element_path}/{element_name}'
240+
try:
241+
skill_dir = api.download_skill(
242+
skill_id=skill_id, local_dir=local_dir)
243+
return (skill_id, skill_dir, None)
244+
except Exception as e:
245+
return (skill_id, None, str(e))
246+
247+
concurrent_download(
248+
_download_one_skill,
249+
valid_elements,
250+
max_workers=self.args.max_workers,
251+
item_name='skill')
194252
else:
195253
pass # noop

modelscope/cli/skills.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import logging
3+
import sys
4+
from argparse import ArgumentParser
5+
6+
from modelscope.cli.base import CLICommand
7+
from modelscope.cli.utils import concurrent_download
8+
from modelscope.hub.api import HubApi
9+
from modelscope.hub.constants import DEFAULT_SKILLS_DIR
10+
from modelscope.utils.logger import get_logger
11+
12+
logger = get_logger(log_level=logging.WARNING)
13+
14+
15+
def subparser_func(args):
16+
"""Function which will be called for a specific sub parser."""
17+
return SkillsCMD(args)
18+
19+
20+
class SkillsCMD(CLICommand):
21+
"""Command for managing skills."""
22+
23+
name = 'skills'
24+
25+
def __init__(self, args):
26+
self.args = args
27+
28+
@staticmethod
29+
def define_args(parsers: ArgumentParser):
30+
"""Define args for skills command."""
31+
parser = parsers.add_parser(SkillsCMD.name)
32+
subparsers = parser.add_subparsers(
33+
dest='skills_action', help='skills subcommands')
34+
35+
# 'add' subcommand
36+
add_parser = subparsers.add_parser(
37+
'add', help='Download and install skills')
38+
add_parser.add_argument(
39+
'skill_ids',
40+
type=str,
41+
nargs='+',
42+
help='Skill IDs to download, in format: <path>/<name>')
43+
add_parser.add_argument(
44+
'--token',
45+
type=str,
46+
default=None,
47+
help='Access token for authentication')
48+
add_parser.add_argument(
49+
'--local_dir',
50+
type=str,
51+
default=None,
52+
help='Target directory for skills (default: ~/.agents/skills)')
53+
add_parser.add_argument(
54+
'--max-workers',
55+
type=int,
56+
default=8,
57+
help='Maximum concurrent downloads (default: 8)')
58+
add_parser.set_defaults(func=subparser_func)
59+
60+
def execute(self):
61+
if not hasattr(self.args,
62+
'skills_action') or not self.args.skills_action:
63+
print('Usage: modelscope skills add <skill_id1> <skill_id2> ...')
64+
return
65+
66+
if not hasattr(self.args, 'skill_ids') or not self.args.skill_ids:
67+
print('No skill IDs provided. Usage: modelscope skills add '
68+
'<skill_id1> <skill_id2> ...')
69+
return
70+
71+
api = HubApi(token=self.args.token)
72+
local_dir = self.args.local_dir or DEFAULT_SKILLS_DIR
73+
74+
skill_ids = self.args.skill_ids
75+
print(f'Downloading {len(skill_ids)} skill(s)...')
76+
77+
if len(skill_ids) == 1:
78+
# Single skill download
79+
try:
80+
skill_dir = api.download_skill(
81+
skill_id=skill_ids[0], local_dir=local_dir)
82+
print(f'Skill downloaded to: {skill_dir}')
83+
except Exception as e:
84+
print(f'Failed to download skill {skill_ids[0]}: {e}')
85+
sys.exit(1)
86+
else:
87+
# Multiple skills - concurrent download
88+
def _download_one(skill_id):
89+
try:
90+
skill_dir = api.download_skill(
91+
skill_id=skill_id, local_dir=local_dir)
92+
return (skill_id, skill_dir, None)
93+
except Exception as e:
94+
return (skill_id, None, str(e))
95+
96+
concurrent_download(
97+
_download_one,
98+
skill_ids,
99+
max_workers=self.args.max_workers,
100+
item_name='skill')

modelscope/cli/utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
# Copyright (c) Alibaba, Inc. and its affiliates.
2+
import sys
3+
from concurrent.futures import ThreadPoolExecutor, as_completed
4+
5+
6+
def concurrent_download(download_fn, items, max_workers=8, item_name='item'):
7+
"""Download multiple items concurrently with progress reporting.
8+
9+
Args:
10+
download_fn: Callable that takes an item and returns
11+
(identifier, result_path, error_string_or_None).
12+
items: List of items to download.
13+
max_workers (int): Maximum concurrent workers.
14+
item_name (str): Display name for the item type.
15+
16+
Returns:
17+
tuple: (succeeded_list, failed_list).
18+
"""
19+
succeeded = []
20+
failed = []
21+
22+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
23+
futures = {executor.submit(download_fn, item): item for item in items}
24+
for future in as_completed(futures):
25+
identifier, result_path, error = future.result()
26+
if error:
27+
failed.append((identifier, error))
28+
print(f'Failed to download {item_name} {identifier}: {error}')
29+
else:
30+
succeeded.append((identifier, result_path))
31+
print(f'Downloaded {item_name} {identifier} -> {result_path}')
32+
33+
print(f'\nDownload complete: {len(succeeded)} succeeded, '
34+
f'{len(failed)} failed')
35+
if failed:
36+
print(f'Failed {item_name}s:')
37+
for identifier, error in failed:
38+
print(f' {identifier}: {error}')
39+
sys.exit(1)
40+
41+
return succeeded, failed

modelscope/hub/api.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import tempfile
1414
import uuid
1515
import warnings
16+
import zipfile
1617
from collections import defaultdict
1718
from http import HTTPStatus
1819
from http.cookiejar import CookieJar
@@ -3043,6 +3044,111 @@ def set_repo_visibility(self,
30433044

30443045
return resp
30453046

3047+
# ============= Collection API =============
3048+
def get_collection(self,
3049+
collection_id: str,
3050+
repo_type: str = 'skill',
3051+
page_number: int = 1,
3052+
page_size: int = 50) -> dict:
3053+
"""Get collection details and its elements.
3054+
3055+
Args:
3056+
collection_id (str): The collection ID (Fid).
3057+
repo_type (str): Element type filter, only 'skill' is supported currently.
3058+
page_number (int): Page number for pagination.
3059+
page_size (int): Page size for pagination.
3060+
3061+
Returns:
3062+
dict: Collection details including elements.
3063+
3064+
Raises:
3065+
ValueError: If repo_type is not 'skill'.
3066+
RequestError: If the API request fails.
3067+
"""
3068+
if repo_type != 'skill':
3069+
raise ValueError(
3070+
f'repo_type={repo_type} is not supported, '
3071+
'only "skill" is currently supported.')
3072+
cookies = self.get_cookies()
3073+
path = f'{self.endpoint}/api/v1/collections'
3074+
params = {
3075+
'Fid': collection_id,
3076+
'ElementType': repo_type,
3077+
'PageNumber': page_number,
3078+
'PageSize': page_size,
3079+
}
3080+
r = self.session.get(path, params=params, cookies=cookies,
3081+
headers=self.builder_headers(self.headers))
3082+
raise_for_http_status(r)
3083+
d = r.json()
3084+
raise_on_error(d)
3085+
return d[API_RESPONSE_FIELD_DATA]
3086+
3087+
def download_skill(self, skill_id: str,
3088+
local_dir: Optional[str] = None) -> str:
3089+
"""Download a single skill archive and extract it.
3090+
3091+
Args:
3092+
skill_id (str): The skill identifier in format '<path>/<name>'.
3093+
local_dir (Optional[str]): Target directory for extraction.
3094+
Defaults to current directory.
3095+
3096+
Returns:
3097+
str: Path to the extracted skill directory.
3098+
3099+
Raises:
3100+
ValueError: If skill_id format is invalid.
3101+
RequestError: If the download request fails.
3102+
"""
3103+
element_path, element_name = RepoUtils.validate_repo_id(skill_id)
3104+
3105+
cookies = self.get_cookies()
3106+
url = f'{self.endpoint}/api/v1/skills/{element_path}/{element_name}/archive/zip/master'
3107+
3108+
if local_dir is None:
3109+
local_dir = os.getcwd()
3110+
os.makedirs(local_dir, exist_ok=True)
3111+
3112+
# Build skill directory name: <element_path>__<element_name>__master
3113+
skill_dir_name = f'{element_path}__{element_name}__master'
3114+
skill_dir = os.path.join(local_dir, skill_dir_name)
3115+
3116+
r = self.session.get(url, stream=True, cookies=cookies,
3117+
headers=self.builder_headers(self.headers))
3118+
raise_for_http_status(r)
3119+
3120+
# Save to temp zip file then extract
3121+
zip_path = os.path.join(local_dir, f'{element_name}.zip')
3122+
try:
3123+
with open(zip_path, 'wb') as f:
3124+
for chunk in r.iter_content(chunk_size=8192):
3125+
if chunk:
3126+
f.write(chunk)
3127+
3128+
# Clean existing directory to avoid corrupted state
3129+
if os.path.exists(skill_dir):
3130+
shutil.rmtree(skill_dir)
3131+
os.makedirs(skill_dir, exist_ok=True)
3132+
with zipfile.ZipFile(zip_path, 'r') as zf:
3133+
zf.extractall(skill_dir)
3134+
3135+
# Flatten if zip contains a single top-level directory
3136+
entries = os.listdir(skill_dir)
3137+
if len(entries) == 1:
3138+
nested_dir = os.path.join(skill_dir, entries[0])
3139+
if os.path.isdir(nested_dir):
3140+
for item in os.listdir(nested_dir):
3141+
shutil.move(
3142+
os.path.join(nested_dir, item),
3143+
os.path.join(skill_dir, item))
3144+
os.rmdir(nested_dir)
3145+
finally:
3146+
if os.path.exists(zip_path):
3147+
os.remove(zip_path)
3148+
3149+
logger.info(f'Skill {element_path}/{element_name} downloaded to {skill_dir}')
3150+
return skill_dir
3151+
30463152

30473153
class ModelScopeConfig:
30483154
path_credential = expanduser(MODELSCOPE_CREDENTIALS_PATH)

modelscope/hub/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@
4141
DEFAULT_MAX_WORKERS = int(
4242
os.getenv('DEFAULT_MAX_WORKERS', min(8,
4343
os.cpu_count() + 4)))
44+
DEFAULT_SKILLS_DIR = os.path.join(os.path.expanduser('~'), '.agents', 'skills')
4445

4546
# Upload check env
4647
UPLOAD_MAX_FILE_SIZE = int(

0 commit comments

Comments
 (0)