Skip to content

Commit 3c9087d

Browse files
committed
siluette
1 parent 76973fd commit 3c9087d

File tree

4 files changed

+44
-4
lines changed

4 files changed

+44
-4
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
# Setup configuration
3030
setuptools.setup(
3131
name="simba-uw-tf-dev",
32-
version="2.4.6",
32+
version="2.4.8",
3333
author="Simon Nilsson, Jia Jie Choong, Sophia Hwang",
3434
author_email="[email protected]",
3535
description="Toolkit for computer classification and analysis of behaviors in experimental animals",

simba/SimBA.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -356,7 +356,7 @@ def activate(box, *args):
356356

357357
label_behavior_frm = CreateLabelFrameWithIcon(parent=tab7, header="LABEL BEHAVIOR", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.LABEL_BEHAVIOR.value)
358358
select_video_btn_new = SimbaButton(parent=label_behavior_frm, txt="Select video (create new video annotation)", img='label_blue', txt_clr='navy', cmd=select_labelling_video, cmd_kwargs={'config_path': lambda :self.config_path, 'threshold_dict': lambda: None, 'setting': lambda: "from_scratch", 'continuing': lambda: False}, thread=False)
359-
select_video_btn_continue = SimbaButton(parent=label_behavior_frm, txt="Select video (continue existing video annotation)", img='label_yellow', txt_clr='darkgoldenrod', cmd=select_labelling_video, cmd_kwargs={'config_path': lambda: self.config_path, 'threshold_dict': lambda:None, 'setting': lambda:None, 'continuing': lambda:True}, thread=False)
359+
select_video_btn_continue = SimbaButton(parent=label_behavior_frm, txt="Select video (continue existing video annotation)", img='label_yellow', txt_clr='darkgoldenrod', cmd=select_labelling_video, cmd_kwargs={'config_path': lambda: self.config_path, 'threshold_dict': lambda:None, 'setting': lambda: None, 'continuing': lambda:True}, thread=False)
360360

361361
label_thirdpartyann = CreateLabelFrameWithIcon(parent=tab7, header="IMPORT THIRD-PARTY BEHAVIOR ANNOTATIONS", icon_name=Keys.DOCUMENTATION.value, icon_link=Links.THIRD_PARTY_ANNOTATION.value)
362362
button_importmars = SimbaButton(parent=label_thirdpartyann, txt="Import MARS Annotation (select folder with .annot files)", txt_clr="blue", cmd=self.importMARS, thread=False)

simba/labelling/labelling_interface.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __init__(self,
6363
config_path: Union[str, os.PathLike],
6464
file_path: Union[str, os.PathLike],
6565
threshold_dict: Optional[Dict[str, float]] = None,
66-
setting: Literal["from_scratch", "pseudo"] = "pseudo",
66+
setting: Optional[Literal["from_scratch", "pseudo"]] = "from_scratch",
6767
continuing: Optional[bool] = False):
6868

6969
ConfigReader.__init__(self, config_path=config_path)
@@ -360,7 +360,8 @@ def select_labelling_video(config_path: Union[str, os.PathLike],
360360
check_file_exist_and_readable(file_path=config_path)
361361
if threshold_dict is not None:
362362
check_valid_dict(x=threshold_dict, valid_key_dtypes=(str,), valid_values_dtypes=(float,))
363-
check_str(name='setting', value=setting, options=('pseudo', "from_scratch"))
363+
if setting is not None:
364+
check_str(name='setting', value=setting, options=('pseudo', "from_scratch",))
364365
check_valid_boolean(value=[continuing], source=select_labelling_video.__name__)
365366

366367

simba/mixins/statistics_mixin.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3931,6 +3931,45 @@ def calinski_harabasz(x: np.ndarray, y: np.ndarray) -> float:
39313931
else:
39323932
return extra_dispersion * (x.shape[0] - n_labels) / denominator
39333933

3934+
def silhouette_score(self, x: np.ndarray, y: np.ndarray) -> float:
3935+
"""
3936+
Compute the silhouette score for the given dataset and labels.
3937+
3938+
:param np.ndarray x: The dataset as a 2D NumPy array of shape (n_samples, n_features).
3939+
:param np.ndarray y: Cluster labels for each data point as a 1D NumPy array of shape (n_samples,).
3940+
:returns: The average silhouette score for the dataset.
3941+
:rtype: float
3942+
3943+
:example:
3944+
>>> x, y = make_blobs(n_samples=10000, n_features=400, centers=5, cluster_std=10, center_box=(-1, 1))
3945+
>>> score = silhouette_score(x=x, y=y)
3946+
3947+
>>> from sklearn.metrics import silhouette_score as sklearn_silhouette # SKLEARN ALTERNATIVE
3948+
>>> score_sklearn = sklearn_silhouette(x, y)
3949+
3950+
"""
3951+
dists = cdist(x, x)
3952+
results = np.full(x.shape[0], fill_value=-1.0, dtype=np.float32)
3953+
cluster_ids = np.unique(y)
3954+
cluster_indices = {cluster_id: np.argwhere(y == cluster_id).flatten() for cluster_id in cluster_ids}
3955+
3956+
for i in range(x.shape[0]):
3957+
intra_idx = cluster_indices[y[i]]
3958+
if len(intra_idx) <= 1:
3959+
a_i = 0.0
3960+
else:
3961+
intra_distances = dists[i, intra_idx]
3962+
a_i = np.sum(intra_distances) / (intra_distances.shape[0] - 1)
3963+
b_i = np.inf
3964+
for cluster_id in cluster_ids:
3965+
if cluster_id != y[i]:
3966+
inter_idx = cluster_indices[cluster_id]
3967+
inter_distances = dists[i, inter_idx]
3968+
b_i = min(b_i, np.mean(inter_distances))
3969+
results[i] = (b_i - a_i) / max(a_i, b_i)
3970+
3971+
return np.mean(results)
3972+
39343973
@staticmethod
39353974
def adjusted_rand(x: np.ndarray, y: np.ndarray) -> float:
39363975
"""

0 commit comments

Comments
 (0)