Skip to content

Commit fcef6cb

Browse files
Copilotsgerlach
andcommitted
Refactor: Extract helper methods and improve code quality
Co-authored-by: sgerlach <[email protected]>
1 parent f180ecf commit fcef6cb

File tree

2 files changed

+56
-45
lines changed

2 files changed

+56
-45
lines changed

stackhawk_mcp/server.py

Lines changed: 44 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,24 +2192,12 @@ async def _analyze_sensitive_data_trends(self, org_id: str, analysis_period: str
21922192
async def _check_repository_attack_surface(self, repo_name: str = None, org_id: str = None, include_vulnerabilities: bool = True, include_apps: bool = True, **kwargs) -> Dict[str, Any]:
21932193
"""Check if a repository name exists in StackHawk attack surface and get security information"""
21942194
try:
2195-
# Get repo name from current directory if not provided
2196-
if not repo_name:
2197-
repo_name = os.path.basename(os.getcwd())
2198-
2199-
# Get org_id if not provided
2200-
if not org_id:
2201-
user_info = await self.client.get_user_info()
2202-
org_id = user_info["user"]["external"]["organizations"][0]["organization"]["id"]
2195+
# Get repo name and org_id using helper methods
2196+
repo_name = self._get_current_repository_name(repo_name)
2197+
org_id = await self._get_organization_id(org_id)
22032198

2204-
# List all repositories in the organization
2205-
repos_response = await self.client.list_repositories(org_id, pageSize=1000)
2206-
repositories = repos_response.get("repositories", [])
2207-
2208-
# Find matching repositories (case-insensitive)
2209-
matching_repos = [
2210-
repo for repo in repositories
2211-
if repo.get("name", "").lower() == repo_name.lower()
2212-
]
2199+
# Find matching repositories
2200+
matching_repos = await self._find_matching_repositories(repo_name, org_id)
22132201

22142202
result = {
22152203
"repository_name": repo_name,
@@ -2276,24 +2264,12 @@ async def _check_repository_attack_surface(self, repo_name: str = None, org_id:
22762264
async def _check_repository_sensitive_data(self, repo_name: str = None, org_id: str = None, data_type_filter: str = "All", include_remediation: bool = True, **kwargs) -> Dict[str, Any]:
22772265
"""Check if a repository has sensitive data findings in StackHawk"""
22782266
try:
2279-
# Get repo name from current directory if not provided
2280-
if not repo_name:
2281-
repo_name = os.path.basename(os.getcwd())
2282-
2283-
# Get org_id if not provided
2284-
if not org_id:
2285-
user_info = await self.client.get_user_info()
2286-
org_id = user_info["user"]["external"]["organizations"][0]["organization"]["id"]
2267+
# Get repo name and org_id using helper methods
2268+
repo_name = self._get_current_repository_name(repo_name)
2269+
org_id = await self._get_organization_id(org_id)
22872270

2288-
# List all repositories in the organization
2289-
repos_response = await self.client.list_repositories(org_id, pageSize=1000)
2290-
repositories = repos_response.get("repositories", [])
2291-
2292-
# Find matching repositories (case-insensitive)
2293-
matching_repos = [
2294-
repo for repo in repositories
2295-
if repo.get("name", "").lower() == repo_name.lower()
2296-
]
2271+
# Find matching repositories
2272+
matching_repos = await self._find_matching_repositories(repo_name, org_id)
22972273

22982274
result = {
22992275
"repository_name": repo_name,
@@ -2354,10 +2330,8 @@ async def _check_repository_sensitive_data(self, repo_name: str = None, org_id:
23542330
async def _list_application_repository_connections(self, org_id: str = None, include_repo_details: bool = True, include_app_details: bool = True, filter_connected_only: bool = False, **kwargs) -> Dict[str, Any]:
23552331
"""List connections between StackHawk applications and code repositories"""
23562332
try:
2357-
# Get org_id if not provided
2358-
if not org_id:
2359-
user_info = await self.client.get_user_info()
2360-
org_id = user_info["user"]["external"]["organizations"][0]["organization"]["id"]
2333+
# Get org_id using helper method
2334+
org_id = await self._get_organization_id(org_id)
23612335

23622336
# Get all applications and repositories
23632337
apps_response = await self.client.list_applications(org_id, pageSize=1000)
@@ -2463,6 +2437,36 @@ async def _list_application_repository_connections(self, org_id: str = None, inc
24632437
"organization_id": org_id
24642438
}
24652439

2440+
def _get_current_repository_name(self, repo_name: str = None) -> str:
2441+
"""Get repository name from parameter or auto-detect from current directory"""
2442+
if repo_name:
2443+
return repo_name
2444+
return os.path.basename(os.getcwd())
2445+
2446+
async def _get_organization_id(self, org_id: str = None) -> str:
2447+
"""Get organization ID from parameter or auto-detect from user info"""
2448+
if org_id:
2449+
return org_id
2450+
2451+
user_info = await self.client.get_user_info()
2452+
organizations = user_info.get("user", {}).get("external", {}).get("organizations", [])
2453+
2454+
if not organizations:
2455+
raise ValueError("No organizations found for user")
2456+
2457+
return organizations[0]["organization"]["id"]
2458+
2459+
async def _find_matching_repositories(self, repo_name: str, org_id: str) -> List[Dict[str, Any]]:
2460+
"""Find repositories matching the given name (case-insensitive)"""
2461+
repos_response = await self.client.list_repositories(org_id, pageSize=1000)
2462+
repositories = repos_response.get("repositories", [])
2463+
2464+
# Find matching repositories (case-insensitive)
2465+
return [
2466+
repo for repo in repositories
2467+
if repo.get("name", "").lower() == repo_name.lower()
2468+
]
2469+
24662470
def _calculate_name_similarity(self, name1: str, name2: str) -> float:
24672471
"""Calculate similarity between two names using simple string matching"""
24682472
if not name1 or not name2:
@@ -2483,10 +2487,8 @@ def _calculate_name_similarity(self, name1: str, name2: str) -> float:
24832487
async def _get_comprehensive_sensitive_data_summary(self, org_id: str = None, time_period: str = "30d", include_trends: bool = True, include_critical_only: bool = False, include_recommendations: bool = True, group_by: str = "data_type", **kwargs) -> Dict[str, Any]:
24842488
"""Get a comprehensive sensitive data summary combining multiple analysis approaches"""
24852489
try:
2486-
# Get org_id if not provided
2487-
if not org_id:
2488-
user_info = await self.client.get_user_info()
2489-
org_id = user_info["user"]["external"]["organizations"][0]["organization"]["id"]
2490+
# Get org_id using helper method
2491+
org_id = await self._get_organization_id(org_id)
24902492

24912493
# Get all sensitive data findings
24922494
if time_period == "all":

test_new_tools.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,8 @@ async def test_new_tools():
4747
# Test 1: Check Repository Attack Surface
4848
print("1. Testing check_repository_attack_surface...")
4949
try:
50-
# Test with the current repository name
51-
current_repo = "stackhawk-mcp" # This repo
50+
# Test with the current repository name (auto-detected from directory)
51+
current_repo = os.path.basename(os.getcwd()) # Dynamic detection
5252
result = await server._check_repository_attack_surface(
5353
repo_name=current_repo,
5454
include_vulnerabilities=True,
@@ -159,7 +159,16 @@ async def test_new_tools():
159159
"get_sensitive_data_summary"
160160
]
161161

162-
found_tools = [tool.name for tool in tools if tool.name in new_tool_names]
162+
# Handle different possible return types from _list_tools_handler
163+
tool_names = []
164+
if tools:
165+
for tool in tools:
166+
if hasattr(tool, 'name'):
167+
tool_names.append(tool.name)
168+
elif isinstance(tool, dict) and 'name' in tool:
169+
tool_names.append(tool['name'])
170+
171+
found_tools = [name for name in tool_names if name in new_tool_names]
163172
print(f"✅ Found {len(found_tools)}/{len(new_tool_names)} new tools in MCP interface")
164173
for tool_name in found_tools:
165174
print(f" ✓ {tool_name}")

0 commit comments

Comments
 (0)