Skip to content
Snippets Groups Projects
Commit e0c1612c authored by a272-jones's avatar a272-jones
Browse files

Replaced. with second part, but values are way too high

parent db3835de
No related branches found
No related tags found
No related merge requests found
......@@ -162,11 +162,141 @@ test_r2 = r2_score(y_test, y_pred_test)
### OUTPUTS
print(f"Train RMSE: {train_rmse:.2f}")
print(f"Test RMSE: {test_rmse:.2f}")
print(f"Train R² Score: {train_r2:.4f}")
print(f"Test R² Score: {test_r2:.4f}")
#print(f"Train RMSE: {train_rmse:.2f}")
#print(f"Test RMSE: {test_rmse:.2f}")
#print(f"Train R² Score: {train_r2:.4f}")
#print(f"Test R² Score: {test_r2:.4f}")
# Display the learned expression
print("\nBest symbolic expression:")
print(symbolic_reg._program)
#print("\nBest symbolic expression:")
#print(symbolic_reg._program)
#Frankenstein time!
class TrainedSymbolicRegressor:
def __init__(self):
self.model = None
self._program = symbolic_reg._program
self._setup_model()
def _setup_model(self):
# Define the predict function directly without printing
self.model = lambda X: self._predict_sample(X)
def _predict_sample(self, X):
predictions = np.zeros(X.shape[0])
try:
for i in range(X.shape[0]):
# Extract features
x_vals = {}
for j in range(min(19, X.shape[1])):
x_vals[f'X{j + 1}'] = X[i, j] if j < X.shape[1] else 0
X1 = x_vals.get('X1', 0)
X5 = x_vals.get('X5', 0)
X6 = x_vals.get('X6', 0)
X7 = x_vals.get('X7', 0)
X8 = x_vals.get('X8', 0)
X13 = x_vals.get('X13', 0)
X15 = x_vals.get('X15', 0)
def safe_sqrt(x):
return np.sqrt(max(0, x))
def safe_log(x):
return np.log(max(1e-10, abs(x)))
def safe_div(a, b):
return a / (b if abs(b) > 1e-10 else 1e-10)
term1 = safe_sqrt(abs(X5 - X7))
term2 = safe_sqrt(safe_div(term1, X6))
term3 = safe_log(abs(np.cos(X7)))
term4 = np.sin(safe_sqrt(safe_log(safe_sqrt(abs(np.cos(X15))))))
predictions[i] = safe_div(term2, term4)
except Exception:
pass
return predictions
def predict(self, X):
return self.model(X)
def main():
# Load the data
file_path = f"{sys.argv[2]}"
data = pd.read_csv(file_path)
# Target column from arguments
target_col = f"{sys.argv[1]}"
X = data.drop(target_col, axis=1)
y = data[target_col]
# Categorize columns by data type
numeric_cols = X.select_dtypes(include=[np.number]).columns.tolist()
categorical_cols = X.select_dtypes(include=['object']).columns.tolist()
# Process categorical columns
for col in categorical_cols:
# Check if column contains time periods
if col == 'Injury_Prognosis' or any(
re.search(r'\d+\s*(?:day|week|month|year)', str(val)) for val in X[col].dropna().iloc[:20]):
X[col] = X[col].apply(convert_time_period)
# Fill missing values with median
median_value = X[col].median()
X[col].fillna(median_value, inplace=True)
else:
# Label encoding for categorical variables
X[col].fillna("MISSING_VALUE", inplace=True)
le = LabelEncoder()
X[col] = le.fit_transform(X[col])
# Drop any remaining non-numeric columns
non_numeric_cols = X.select_dtypes(exclude=[np.number]).columns.tolist()
if non_numeric_cols:
X = X.drop(columns=non_numeric_cols)
# Handle missing values with imputation
num_imputer = SimpleImputer(strategy='median')
X_imputed = pd.DataFrame(num_imputer.fit_transform(X), columns=X.columns)
# Handle missing values in target
target_missing = y.isna().sum()
if target_missing > 0:
mask = y.notna()
X_clean = X_imputed[mask]
y_clean = y[mask]
else:
X_clean = X_imputed
y_clean = y.copy()
# Scale the features
scaler = StandardScaler()
X_scaled = scaler.fit_transform(X_clean)
# Create and use the pre-trained model
model = TrainedSymbolicRegressor()
predictions = model.predict(X_scaled)
# Create output dataframe and save to CSV
result_df = pd.DataFrame({
f"Predicted_{target_col}": predictions
})
# Get output file path - same directory as input file but with _predictions suffix
input_path = file_path
input_dir = os.path.dirname(input_path)
input_filename = os.path.basename(input_path)
input_name = os.path.splitext(input_filename)[0]
output_path = os.path.join(input_dir, f"{input_name}_predictions.csv")
result_df.to_csv(output_path, index=False)
print(output_path)
if __name__ == "__main__":
main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment