Skip to content

Commit

Permalink
hpsearch fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ShokouhZolfaghari committed Sep 7, 2024
1 parent 74ddd1f commit 6f7de5c
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 20 deletions.
14 changes: 9 additions & 5 deletions hp_search.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import sys

from src.dataset.dataset_factory import DatasetFactory
from src.evaluation.evaluation_metric_factory import EvaluationMetricFactory
from src.explainer.explainer_factory import ExplainerFactory
from src.oracle.embedder_factory import EmbedderFactory
from src.oracle.oracle_factory import OracleFactory
from src.plotters.plotter_factory import PlotterFactory
from src.utils.context import Context
from hpsearch.hp_tuner import HpTuner

Expand All @@ -16,11 +20,11 @@

context.factories['datasets'] = DatasetFactory(context)
context.factories['oracles'] = OracleFactory(context)
#context.factories['embedders'] = EmbedderFactory(context)
#context.factories['explainers'] = ExplainerFactory(context)
#context.factories['metrics'] = EvaluationMetricFactory(context.conf)
#context.factories['plotters'] = PlotterFactory(context)
context.factories['embedders'] = EmbedderFactory(context)
context.factories['explainers'] = ExplainerFactory(context)
context.factories['metrics'] = EvaluationMetricFactory(context.conf)
context.factories['plotters'] = PlotterFactory(context)

hp_tuner = HpTuner(context, 100)
hp_tuner = HpTuner(context, train_oracle=True, train_expainer=False, n_trials=100)
hp_tuner.optimize()

46 changes: 31 additions & 15 deletions hpsearch/hp_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

class HpTuner():

def __init__(self, context: Context, n_trials: int) -> None:
def __init__(self, context: Context, train_oracle: bool, train_expainer: bool, n_trials: int) -> None:
self.context = context
self.n_trials = n_trials
self.dataset = context.factories['datasets'].get_dataset(context.conf['do-pairs'][0]['dataset'])
Expand All @@ -20,6 +20,17 @@ def optimize(self):

self.logger.info("Start optimization.")

#search_space = {
# 'num_conv_layers': [1, 2, 3, 4, 5],
# 'num_dense_layers': [1, 2, 3, 4, 5],
# 'conv_booster': [1, 2, 3, 4, 5]
#}

# Define the GridSampler with the search space
#sampler = optuna.samplers.GridSampler(search_space)

# Create a study object and use grid search
#study = optuna.create_study(sampler=sampler, direction='maximize')

#Create a study object and optimize the objective function
study = optuna.create_study(direction='maximize')
Expand All @@ -37,26 +48,31 @@ def objective(self, trial):

#Suggest values of the hyperparameters using a trial object.
#batch_size = trial.trial.suggest_categorical('batch_size', [32, 64])
learning_rate = trial.suggest_float('lr', 1e-3, 1e-1)
weight_decay = trial.suggest_float('weight_decay', 1e-5, 1e-1)
num_conv_layers = trial.suggest_int('num_conv_layers', 1, 2)
num_dense_layers = trial.suggest_int('num_dense_layers', 1, 2)
conv_booster = trial.suggest_int('conv_booster', 1, 2)
learning_rate = trial.suggest_float('lr', 1e-5, 1, log=True)
weight_decay = trial.suggest_float('weight_decay', 1e-5, 1, log=True)
#num_conv_layers = trial.suggest_int('num_conv_layers', 1, 5)
#num_dense_layers = trial.suggest_int('num_dense_layers', 1, 5)
#conv_booster = trial.suggest_int('conv_booster', 1, 5)
linear_decay = trial.suggest_float('linear_decay', 0, 2)

self.oracle_config['parameters']['optimizer']['parameters']['lr'] = learning_rate
self.oracle_config['parameters']['optimizer']['parameters']['weight_decay'] = weight_decay
self.oracle_config['parameters']['model']['parameters']['num_conv_layers'] = num_conv_layers
self.oracle_config['parameters']['model']['parameters']['num_dense_layers'] = num_dense_layers
self.oracle_config['parameters']['model']['parameters']['conv_booster'] = conv_booster
self.oracle_config['parameters']['model']['parameters']['num_conv_layers'] = 4 #num_conv_layers
self.oracle_config['parameters']['model']['parameters']['num_dense_layers'] = 2 #num_dense_layers
self.oracle_config['parameters']['model']['parameters']['conv_booster'] = 5 #conv_booster
self.oracle_config['parameters']['model']['parameters']['linear_decay'] = linear_decay

dataset = self.context.factories['datasets'].get_dataset(self.dataset_config)
parameters = self.oracle_config['parameters']
self.logger.info(f'Parameters: {parameters}')
self.logger.info(f'Trial {trial.number}: Hyperparameters: {parameters}')
oracle = self.context.factories['oracles'].get_oracle(self.oracle_config, dataset)

mean_accuracy = oracle.mean_accuracy
if mean_accuracy is None:
raise ValueError("The oracle did not return a valid accuracy value.")
return mean_accuracy

# Check for accuracy metric
if hasattr(oracle, 'mean_accuracy'):
mean_accuracy = oracle.mean_accuracy
if mean_accuracy is None:
self.logger.error("The oracle did not return a valid accuracy value.")
else:
return mean_accuracy
else:
self.logger.error("Oracle does not have 'mean_accuracy' attribute.")

0 comments on commit 6f7de5c

Please sign in to comment.