import os
import re
import sys

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns  # NOTE(review): not used below; kept so the import set is unchanged

# Accepted values for the ``kind`` CLI argument. Each one is used as a column
# name in the delta CSV files (presumably one column per metric — confirm
# against whatever produces those files).
VALID_KINDS = ("average", "best", "worst", "missed")


def _load_deltas(input_directory, pattern):
    """Read every ``<pattern>_DELTA_*.xmlv3.csv`` file in *input_directory*.

    Each file's rows get an extra ``'Task Set'`` column holding the source
    filename. Returns a single DataFrame with a fresh 0..n-1 index, or an
    empty DataFrame when no file matches.
    """
    # Raw f-string so ``\.`` reaches the regex engine instead of being treated
    # as a (deprecated) string escape. *pattern* is itself a regex.
    file_regex = re.compile(rf'{pattern}_DELTA_.*\.xmlv3\.csv')
    frames = []
    # sorted(): make row order independent of the filesystem's listing order.
    for filename in sorted(os.listdir(input_directory)):
        # fullmatch also anchors the ``.csv`` suffix; re.match only anchored
        # the start, so e.g. ``*.xmlv3.csv.bak`` would have been accepted.
        if not file_regex.fullmatch(filename):
            continue
        df = pd.read_csv(os.path.join(input_directory, filename))
        df['Task Set'] = filename  # remember which task set each row came from
        frames.append(df)
    if not frames:
        return pd.DataFrame()
    # A single concat avoids the quadratic copying of concat-per-file.
    return pd.concat(frames, ignore_index=True)


def _count_changes(df_all_deltas, kind):
    """Count positive, negative and zero values of column *kind* per task.

    Tasks are ordered by the first integer embedded in their ``'task'`` value.
    Returns columns ``Task``, ``Improvements``, ``Deteriorations`` and
    ``Unchanged``. NaN values are counted in none of the three.
    """
    df = df_all_deltas.copy()
    # expand=False yields a Series rather than a one-column DataFrame.
    # A task value without any digit produces NaN, and astype(int) will raise.
    df['num'] = (
        df['task'].astype(str).str.extract(r'(\d+)', expand=False).astype(int)
    )
    df = df.sort_values(by='num', kind='stable')
    values = df[kind]
    flags = pd.DataFrame({
        'Task': df['task'],
        'Improvements': values > 0,
        'Deteriorations': values < 0,
        'Unchanged': values == 0,
    })
    # sort=False keeps groups in first-appearance order, i.e. the numeric
    # task order established above; summing booleans yields integer counts.
    return flags.groupby('Task', sort=False).sum().reset_index()


def _plot_counts(df_grouped, kind, output_filepath):
    """Draw *df_grouped* as a grouped bar chart and save it to *output_filepath*."""
    title_fontsize = 16  # adjust as needed
    axes_label_fontsize = 14  # adjust as needed
    tick_fontsize = 12  # adjust as needed

    plt.figure(figsize=(12, 8))
    df_grouped.plot(x='Task', kind='bar', ax=plt.gca())
    plt.title(f'Grouped {kind.capitalize()} Values', fontsize=title_fontsize)
    plt.xlabel('Task', fontsize=axes_label_fontsize)
    plt.ylabel(kind.capitalize(), fontsize=axes_label_fontsize)
    plt.legend(title='Nature')
    plt.xticks(rotation=0, fontsize=tick_fontsize)
    plt.yticks(fontsize=tick_fontsize)
    plt.grid(axis='y', linestyle='--', alpha=0.7)
    plt.margins(x=0, y=0)  # fit the bars tightly against the axes
    plt.tight_layout()
    plt.savefig(output_filepath)
    # The figure is not shown interactively; to view it, call plt.show()
    # before this close.
    plt.close()


def main(input_directory, pattern, kind):
    """Summarise per-task *kind* deltas from matching CSV files as a bar chart.

    Reads ``<pattern>_DELTA_*.xmlv3.csv`` files from *input_directory*. For
    each task, it counts how many rows have a positive, negative or zero
    *kind* value. It writes ``response_time_diffs_grouped_<kind>_<pattern>``
    as both ``.png`` (chart) and ``.csv`` (the counts) into the same
    directory. Non-alphanumeric characters are stripped from *pattern* for
    the output name.

    Exits with an error message if no matching file contributes any rows.
    """
    df_all_deltas = _load_deltas(input_directory, pattern)
    if df_all_deltas.empty:
        sys.exit(
            f"No data found in files matching '{pattern}_DELTA_.*\\.xmlv3\\.csv'"
            f" in {input_directory}"
        )

    df_grouped = _count_changes(df_all_deltas, kind)

    clean_pattern = re.sub(r'[^a-zA-Z0-9_]', '', pattern)
    basename = f'response_time_diffs_grouped_{kind}_{clean_pattern}'
    _plot_counts(df_grouped, kind, os.path.join(input_directory, basename + '.png'))
    # Debug output: the same counts as a CSV next to the PNG.
    df_grouped.to_csv(os.path.join(input_directory, basename + '.csv'), index=False)


if __name__ == "__main__":
    # argv[0] is the script itself, so a valid invocation has 2-4 entries.
    # The length test short-circuits before argv[1] is read.
    if not 2 <= len(sys.argv) <= 4 or sys.argv[1] not in VALID_KINDS:
        print("Usage: python script.py average|best|worst|missed [ input_directory [pattern] ]")
        sys.exit(1)
    kind = sys.argv[1]
    input_directory = sys.argv[2] if len(sys.argv) > 2 else '.'
    pattern = sys.argv[3] if len(sys.argv) > 3 else '.*'
    print(kind, input_directory, pattern)
    main(input_directory, pattern, kind)