From 3904f846340ead650c52fe4e81489631d0713edb Mon Sep 17 00:00:00 2001 From: c72-taylor <Charlie3.Taylor@live.uwe.ac.uk> Date: Wed, 7 May 2025 19:52:06 +0100 Subject: [PATCH] model --- Working Models/Linear_training.py | 23 ++++++++++------------- 1 file changed, 10 insertions(+), 13 deletions(-) diff --git a/Working Models/Linear_training.py b/Working Models/Linear_training.py index eb3d013..07a366d 100644 --- a/Working Models/Linear_training.py +++ b/Working Models/Linear_training.py @@ -201,12 +201,12 @@ def main(): sys.stderr = null_file try: - # Load the data using absolute path - file_path = f"{sys.argv[2]}" - data = pd.read_csv(file_path) + # Get command line arguments + target_col = sys.argv[1] + file_path = sys.argv[2] - # Hard-coded target column - target_col = target_col = f"{sys.argv[1]}" + # Load the data + data = pd.read_csv(file_path) # Preprocess the data X_clean, y_clean = preprocess_data(data, target_col) @@ -262,17 +262,15 @@ def main(): "model": model_params } - # Save parameters to JSON file - input_dir = os.path.dirname(file_path) if os.path.dirname(file_path) else "." - input_filename = os.path.basename(file_path) - input_name = os.path.splitext(input_filename)[0] + # Save parameters to JSON file in the script's directory + script_dir = os.path.dirname(os.path.abspath(__file__)) + model_path = os.path.join(script_dir, "Piecewise_Trained.json") - model_path = os.path.join(input_dir, f"{input_name}_model.json") with open(model_path, 'w') as f: json.dump(prediction_params, f, indent=2) - # Save human-readable results - output_path = os.path.join(input_dir, f"{input_name}_training_results.txt") + # Save human-readable results to the same directory + output_path = os.path.join(script_dir, "Training_Results.txt") with open(output_path, "w") as f: f.write(f"Train RMSE: {results['Train RMSE']}\n") f.write(f"Test RMSE: {results['Test RMSE']}\n") @@ -284,7 +282,6 @@ def main(): # Restore original stderr before printing sys.stderr = original_stderr - finally: # Restore original stderr and close the null file sys.stderr = original_stderr -- GitLab