Source code for AutoRPE.UtilsWorkflow.BinaryTreeSearch

from logging import root
import signal
import time
import AutoRPE

[docs] class Counter: """ A simple counter class to track increments. """ def __init__(self): """ Initialize the counter with a starting value of 0. """ self.count = 0
[docs] def up(self, how_much: str=1): """ Increment the counter by the specified amount and return the previous count. Parameters: how_much (int, optional): The amount by which to increment the counter (default is 1). Returns: int: The previous value of the counter before the increment. """ self.count += how_much return self.count - 1
[docs] class GracefulKiller: """ This class is used to allow the user to stop an analysis at any point creating a pause_checkpoint.pkl which can be used in the future to restart the analysis from the same point. It uses the signal library to catch signals. """ kill_now = False def __init__(self): """ Initialize the signal handler to catch SIGINT and SIGTERM signals. """ signal.signal(signal.SIGINT, self.exit_gracefully) signal.signal(signal.SIGTERM, self.exit_gracefully) self.filename = None
[docs] def exit_gracefully(self, signum: int, frame): """ Handle termination signals by setting the kill flag and printing a message. Parameters: signum (int): The signal number received. frame (frame object): The current stack frame (unused in this implementation). """ print("Exit has been activated, will write the checkpoint as soon as possible and exit.") self.kill_now = True
[docs] class BinaryTreeSearch: """ Initializes the experiment with given parameters and sets up necessary components. Parameters: communicator (SSH): The communicator object used to manage communication between remote-local machines. local_folder (str): The local directory to store data. analysis_status (dict): Dictionary used for tracking the experiment's progress (status of each job of the analysis). job_template (str): Filename (path) of template used to generate jobs for the analysis. vault (Vault): The vault containing variables used for the analysis. original_precision_level (int): The precision level to use in the original run (dp=52, sp=23, hp=10). reduced_precision_level (int): The precision level for the reduced precision run (dp=52, sp=23, hp=10). accuracy_test (object): Object containing the accuracy test to validate results. max_running_jobs (int): The maximum number of jobs that can run simultaneously. output_filename (str): The name of the output file where results will be saved. experiment_name (str, optional): The name of the experiment. Defaults to "BinaryTreeSearch". """ def __init__(self, communicator: 'SSH', local_folder: str, analysis_status: dict, job_template: str, vault: 'Vault', original_precision_level: int, reduced_precision_level: int, accuracy_test: object, max_running_jobs: int, output_filename: str, experiment_name: str="BinaryTreeSearch"): # Starting SSH connection self.communicator = communicator self.max_running_jobs = max_running_jobs self.local_folder = local_folder # Parameters of the analysis self.original_precision_level = original_precision_level self.reduced_precision_level = reduced_precision_level # Experiment related parameter self.accuracy_test = accuracy_test self.job_template = job_template self.vault = vault self.analysis_variables = [v for v in self.vault.variables if v.id is not None and v.is_used][:] self.experiment_name = experiment_name # In case of starting from a checkpoint self.analysis_status = analysis_status # Name of the output file self.output_filename = output_filename # The tree job structure self.binary_tree = None # Used to give an incremental ID to jobs (for visualization) self.counter = Counter()
[docs] def setup(self, id_reduced_precision: list, forced_ids: list, id_banned_var: list): """ Sets up the experiment driver in preparation for starting the analysis. Parameters: id_reduced_precision (list): List of variable IDs to use with reduced precision. forced_ids (list): List of variable IDs that must be included in the analysis. id_banned_var (list): List of variable IDs that should be kept in original precision. Returns: None """ from AutoRPE.UtilsWorkflow.Job import Job from AutoRPE.UtilsWorkflow.BinaryTree import BinaryTree print("Starting analysis with %i variables" % len(id_reduced_precision)) # Create the root job of the search # The first set will include all the variables root_job = Job(id_reduced_precision=[ids for ids in id_reduced_precision if ids not in id_banned_var], forced_ids=[ids for ids in forced_ids if ids not in id_banned_var], analysis_variables=self.analysis_variables, banned_variables=id_banned_var, reduced_precision_level=self.reduced_precision_level, communicator=self.communicator, vault=self.vault, template=self.job_template, local_folder=self.local_folder, analysis_status=self.analysis_status, result_filename=self.accuracy_test.filename_to_evaluate, counter=self.counter) # Defining it's level to 0 root_job.level = 0 # Initialize the Tree with the root job binary_tree = BinaryTree(root_job, self.accuracy_test, local_folder=self.local_folder) # Save the experiment driver in the class object self.binary_tree = binary_tree
[docs] def root_job_succeeded(self): """ Checks if the root job of the binary tree search has successfully completed. Returns: bool: True if the root job succeeded, otherwise False. Raises: AssertionError: If the root job failed, an error is raised. """ # In case the root of the tree fails, raise an error if self.binary_tree.root.status == "FAILED": raise AssertionError("Whole set job ended with a failed result. Should check what happened.") # The status of the root will be SUSPENDED until the entire tree has been checked return self.binary_tree.root.status == "SUCCESS"
[docs] def recover_checkpoint(self): """ Recovers the analysis state from a previously saved checkpoint. Returns: BinaryTree: The binary tree object after recovery from the checkpoint. """ import pickle as pkl binary_tree = pkl.load(open(self.checkpoint_name, "rb")) binary_tree.accuracy_test.communicator = self.communicator all_jobs = binary_tree.all() + binary_tree.disinherited for job in all_jobs: job.communicator = self.communicator job.vault = self.vault job.template = self.job_template job.local_folder = self.local_folder job_with_no_parent = [job for job in all_jobs if job.parent is None and job.level == 0] root = job_with_no_parent[0] root.vault = self.vault return binary_tree
[docs] def initial_check(self, forced_id: list=[], banned_id: list=[]): """ Performs an initial test run with original precision to ensure the accuracy test works. Parameters: forced_id (list, optional): List of variable IDs that are forced into the test. banned_id (list, optional): List of variable IDs that are excluded from the test. Returns: None Raises: AssertionError: If the basic test doesn't pass, an error is raised. """ from AutoRPE.UtilsWorkflow.Job import Job import time # Make a original-precision run to be sure that the accuracy test works as expected for var in self.analysis_variables: var.precision = self.original_precision_level check_job = Job(id_reduced_precision=[], forced_ids=forced_id, analysis_variables=self.analysis_variables, banned_variables=banned_id, reduced_precision_level=self.reduced_precision_level, communicator=self.communicator, vault=self.vault, template=self.job_template, result_filename=self.accuracy_test.filename_to_evaluate, local_folder=self.local_folder, analysis_status=self.analysis_status, counter=0 ) # Submit job and wait until it finishes job_status = check_job.run_job() while job_status not in ["COMPLETED", "FAILED"]: time.sleep(5) job_status = check_job.remote_status # Check that the job ran successfully # if self.accuracy_test.evaluate_success(remote_rundir=check_job.remote_rundir): result = check_job.get_result(self.accuracy_test) if result == 'SUCCESS': print("Basic test was successful, continuing with the analysis") else: raise AssertionError("Basic simulation to check that the accuracy test works properly didn't pass")
[docs] def print_root_configuration(self): """ Prints the configuration of the root job to the specified output file. """ # Sort and keep only unique values line_format = "%5i ! %20s %25s %20s %5s\n" # set final variable precision for var in self.analysis_variables: if var.id not in self.binary_tree.root.banned_variables: var.precision = 23 with open(self.output_filename, "w") as output_file: # Add header output_file.write("# ID " + " " * 5 + "" + " " * 13 + "variable name" + " " * 11 + "subprogram name" + " " * 9 + "module name" + " " * 3 + "variable precision" + "\n" + 100 * "#" + "\n") # Put the var in ascending id order self.analysis_variables.sort(key=lambda x: x.id) # Print the variables to file for var in self.analysis_variables: output_file.write(line_format % (var.id, var.name, var.procedure.name, var.procedure.module.name, var.precision))
[docs] def update_banned_variables(self): """ Updates the list of banned variables (to be kept at original precision) and writes them to a new file. """ line_format = "%5i ! %20s %25s %20s %5s\n" with open("banned_variables_newtxt", "w") as output_file: # Add header output_file.write("# ID " + " " * 5 + "" + " " * 13 + "variable name" + " " * 11 + "subprogram name" + " " * 9 + "module name" + " " * 3 + "variable precision" + "\n" + 100 * "#" + "\n") # Put the var in ascending id order self.analysis_variables.sort(key=lambda x: x.id) # Print the variables to file for var in self.analysis_variables: if var.precision == 52: output_file.write(line_format % (var.id, var.name, var.procedure.name, var.procedure.module.name, var.precision))