Skip to content
Snippets Groups Projects
Commit 9ad87329 authored by c72-taylor's avatar c72-taylor
Browse files

model

parent e09998f5
No related branches found
No related tags found
No related merge requests found
......@@ -51,14 +51,11 @@ def convert_time_period(value):
class PiecewiseLinearPredictor:
"""A simplified version of the model that only does predictions"""
def __init__(self, model_params):
self.segments = model_params["segments"]
self.feature_count = model_params["feature_count"]
def predict(self, X, segment_id):
"""Predict using the specified segment"""
segment = self.segments.get(str(segment_id))
if segment is None:
# If segment not found, use the first available segment
......@@ -78,7 +75,6 @@ class PiecewiseLinearPredictor:
def preprocess_data(data, target_col, feature_names):
"""Preprocess the data similar to the training process"""
if target_col in data.columns:
X = data.drop(target_col, axis=1)
y = data[target_col]
......@@ -132,7 +128,6 @@ def preprocess_data(data, target_col, feature_names):
def apply_scaler(X, scaler_params):
"""Apply scaling using the parameters from training"""
mean = np.array(scaler_params["mean"])
scale = np.array(scaler_params["scale"])
......@@ -142,8 +137,6 @@ def apply_scaler(X, scaler_params):
def find_segment(X_scaled, all_segments):
"""Find the most appropriate segment for each data point
This is a simplified approach since we don't have the tree model"""
# For simplicity, we'll use the first segment
# In a real implementation, you would use the tree model to determine the correct segment
segment_id = list(all_segments.keys())[0]
......@@ -159,19 +152,18 @@ def main():
sys.stderr = null_file
try:
# Load the data using absolute path
file_path = f"{sys.argv[2]}"
data = pd.read_csv(file_path)
# Hard-coded target column
target_col = f"{sys.argv[1]}"
# Load the model parameters
input_dir = os.path.dirname(file_path) if os.path.dirname(file_path) else "."
input_filename = os.path.basename(file_path)
input_name = os.path.splitext(input_filename)[0]
script_dir = os.path.dirname(os.path.abspath(__file__))
model_path = os.path.join(input_dir, f"{input_name}_model.json")
model_path = os.path.join(script_dir, "Piecewise_Trained.json")
with open(model_path, 'r') as f:
model_data = json.load(f)
......@@ -210,7 +202,6 @@ def main():
# Restore original stderr before printing
sys.stderr = original_stderr
print(f"Predictions saved to: {csv_output_path}")
finally:
# Restore original stderr and close the null file
......
This diff is collapsed.
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment