In this blog post, we'll walk through the full lifecycle of building a churn prediction system: cleaning and preparing data, tuning an XGBoost model with Optuna, evaluating performance, and serving real-time predictions through a FastAPI-powered API.
We'll also containerize the entire solution with Docker, making it easy to deploy, scale, and integrate into production environments.
We'll use a customer churn dataset to train a model. You can download it from the Kaggle link below.
import optuna
import pickle
import pandas as pd
import numpy as np
import seaborn as sns
import xgboost as xgb
import matplotlib.pyplot as plt
from sklearn.preprocessing import OrdinalEncoder
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.metrics import (
    accuracy_score,
    precision_score,
    recall_score,
    f1_score,
    roc_auc_score,
    confusion_matrix,
    ConfusionMatrixDisplay
)
from optuna import Trial
from numpy import ndarray
from functools import partial
from typing import Any, Dict
First, the dataset is loaded:
df = pd.read_csv("churn.csv")
print(f"Train shape: {df.shape}")  # Train shape: (505207, 12)
df.head()
df.info()
"""
RangeIndex: 505207 entries, 0 to 505206
Data columns (total 12 columns):
 #   Column              Non-Null Count   Dtype
---  ------              --------------   -----
 0   CustomerID          505206 non-null  float64
 1   Age                 505206 non-null  float64
 2   Gender              505206 non-null  object
 3   Tenure              505206 non-null  float64
 4   Usage Frequency     505206 non-null  float64
 5   Support Calls       505206 non-null  float64
 6   Payment Delay       505206 non-null  float64
 7   Subscription Type   505206 non-null  object
 8   Contract Length     505206 non-null  object
 9   Total Spend         505206 non-null  float64
 10  Last Interaction    505206 non-null  float64
 11  Churn               505206 non-null  float64
dtypes: float64(9), object(3)
memory usage: 46.3+ MB
"""
print(df["Churn"].value_counts())
"""
Churn
1.0 280492
0.0 224714
Name: count, dtype: int64
"""
print(df.isnull().sum())
"""
CustomerID           1
Age                  1
Gender               1
Tenure               1
Usage Frequency      1
Support Calls        1
Payment Delay        1
Subscription Type    1
Contract Length      1
Total Spend          1
Last Interaction     1
Churn                1
dtype: int64
"""
All rows with any missing values are removed to ensure data integrity. The CustomerID column, which uniquely identifies each customer but does not contribute to model learning, is dropped.
df.dropna(axis=0, how="any", inplace=True)
df.drop(columns=["CustomerID"], inplace=True)
df.reset_index(drop=True, inplace=True)
The next step is to separate the features from the target and split the data into training and validation sets.
Y = df["Churn"]
X = df.drop(columns=["Churn"])

X_train, X_val, y_train, y_val = train_test_split(X, Y, test_size=0.2, shuffle=True, random_state=42)
To prepare categorical features for modeling, the dataset is encoded using OrdinalEncoder.
PROCESSORS = {}

PROCESSORS["Subscription Type"] = OrdinalEncoder(categories=[["Basic", "Standard", "Premium", "Pro"]])
PROCESSORS["Contract Length"] = OrdinalEncoder(categories=[["Monthly", "Quarterly", "Annual"]])
PROCESSORS["Gender"] = OrdinalEncoder()

for column, processor in PROCESSORS.items():
    X_train[column] = processor.fit_transform(X_train[[column]]).astype(int)
    X_val[column] = processor.transform(X_val[[column]]).astype(int)
The encoders are applied to both the training and validation sets. During training, the encoders are fitted to the training data and then used to transform both sets to ensure consistency.
The correlation matrix helps identify multicollinearity (features that are highly correlated with one another), which can affect model performance, especially for linear models.
plt.figure(figsize=(10, 8))
sns.heatmap(X_train.corr(), annot=True, cmap='coolwarm')
plt.title('Correlation Matrix')
plt.show()
To reduce redundancy and potential overfitting, it is useful to eliminate features that are highly correlated with one another.
def remove_highly_correlated_features(data: pd.DataFrame, threshold: float):
    corr_matrix = data.corr().abs()
    upper = corr_matrix.where(np.triu(np.ones(corr_matrix.shape), k=1).astype(bool))
    to_drop = [column for column in upper.columns if any(upper[column] > threshold)]
    print(f"Features to drop due to high correlation: {to_drop}")
    df_reduced = data.drop(columns=to_drop)
    return df_reduced, to_drop
X_train, to_drop = remove_highly_correlated_features(X_train, threshold=0.5)
if to_drop:
X_val = X_val.drop(columns=to_drop)
"""
Features to drop due to high correlation: []
"""
In this case, the output indicates that no features exceeded the correlation threshold, so the datasets remain unchanged.
To fine-tune the hyperparameters of an XGBoost classifier, we leverage Optuna, an efficient hyperparameter optimization framework. This approach automates the search for the best model configuration by exploring combinations that maximize a chosen evaluation metric.
def optuna_objective(
    trial: Trial,
    param_grid: Dict[str, Any],
    estimator: Any,
    X: ndarray,
    y: ndarray,
    cross_params: Dict[str, Any]
) -> float:
    trial_params = {key: value(trial) if callable(value) else value for key, value in param_grid.items()}
    model = estimator(**trial_params)
    scores = cross_val_score(model, X, y, **cross_params)
    return scores.mean()
The optuna_objective function builds and evaluates a model using parameters suggested by Optuna. It accepts a parameter space (param_grid), a model estimator, and a cross-validation setup. It returns the average ROC AUC score across folds, an effective metric for binary classification tasks.
The space dictionary defines the range of values for each hyperparameter. These include common XGBoost parameters like learning_rate, max_depth, and regularization terms (reg_alpha, reg_lambda), with value ranges guided by domain best practices.
space = {
    'n_estimators': lambda trial: trial.suggest_int('n_estimators', 5, 150),
    'max_depth': lambda trial: trial.suggest_int('max_depth', 2, 5),
    'learning_rate': lambda trial: trial.suggest_float('learning_rate', 0.01, 0.1),
    'subsample': lambda trial: trial.suggest_float('subsample', 0.5, 1),
    'colsample_bytree': lambda trial: trial.suggest_float('colsample_bytree', 0.5, 1),
    'gamma': lambda trial: trial.suggest_float('gamma', 0, 0.3),
    'reg_alpha': lambda trial: trial.suggest_float('reg_alpha', 1e-3, 10.0, log=True),
    'reg_lambda': lambda trial: trial.suggest_float('reg_lambda', 1e-3, 10.0, log=True),
}

cross_params = {
    'cv': 5,
    'scoring': 'roc_auc',
    'n_jobs': -1,
    'verbose': 0,
    'error_score': 'raise',
}
estimator = xgb.XGBClassifier
Using 5-fold cross-validation ensures that each configuration is evaluated on multiple data splits, helping avoid overfitting to any one partition.
optuna.create_study initiates the optimization study, aiming to maximize the ROC AUC score. The study.optimize call runs the objective function for 5 trials; this can be scaled up for deeper tuning.
partial_objective = partial(optuna_objective, estimator=estimator, param_grid=space, X=X_train, y=y_train.ravel(), cross_params=cross_params)

study = optuna.create_study(direction='maximize')
study.optimize(partial_objective, n_trials=5)
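Once the trials finish, the study object exposes the best score and parameter set it found. A quick check (not part of the original walkthrough; the exact values depend on the run) might look like this:
# Inspect the search outcome; values will vary between runs
print(f"Best ROC AUC: {study.best_value:.4f}")
print(f"Best params : {study.best_params}")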
After hyperparameter tuning with Optuna, the next step is to train the final model using the best parameters found during optimization.
model = estimator(**study.best_params)
model.fit(X_train, y_train)

y_pred = model.predict(X_val)
y_pred_proba = model.predict_proba(X_val)[:, 1]
With the model trained and predictions generated, it is time to assess its performance using key classification metrics. These metrics provide a comprehensive view of how well the model identifies churners versus non-churners:
# Metrics
accuracy = accuracy_score(y_val, y_pred)
precision = precision_score(y_val, y_pred, zero_division=0)
recall = recall_score(y_val, y_pred)
f1 = f1_score(y_val, y_pred)
roc_auc = roc_auc_score(y_val, y_pred_proba)

# Display
print(f"Accuracy : {accuracy:.4f}")
print(f"Precision : {precision:.4f}")
print(f"Recall : {recall:.4f}")
print(f"F1 Score : {f1:.4f}")
print(f"ROC AUC Score: {roc_auc:.4f}")
"""
Accuracy : 0.9184
Precision : 0.8961
Recall : 0.9648
F1 Score : 0.9292
ROC AUC Score: 0.9470
"""
- Accuracy: Measures the proportion of correctly predicted samples. At 91.84%, the model demonstrates strong overall performance.
- Precision: Indicates how many predicted positives were truly positive. A precision of 89.61% suggests that false positives are relatively rare.
- Recall: Captures how many actual positives were correctly identified. The high recall of 96.48% reflects the model's effectiveness in catching churn cases, which is crucial in retention scenarios.
- F1 Score: Harmonizes precision and recall; at 92.92%, it shows that the model balances these metrics well.
- ROC AUC: Reflects the model's ability to discriminate between classes across thresholds. A score of 0.9470 signals excellent discriminatory power.
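As a quick sanity check, the reported F1 score is simply the harmonic mean of the precision and recall listed above:
# F1 = 2 * P * R / (P + R)
p, r = 0.8961, 0.9648
print(2 * p * r / (p + r))  # ~0.9292, matching the F1 Score above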
To further interpret the model's performance, a confusion matrix provides a visual summary of prediction outcomes.
# Confusion Matrix
cm = confusion_matrix(y_val, y_pred)
disp = ConfusionMatrixDisplay(confusion_matrix=cm)
disp.plot()
With the model and encoders trained and validated, the final step in the machine learning pipeline is to serialize these components. This allows them to be reused in a production environment, such as an API, without retraining.
pickle.dump(model, open("./api/models/churn_model.pkl", "wb"))
pickle.dump(PROCESSORS["Subscription Type"], open("./api/models/subscription_type_encoder.pkl", "wb"))
pickle.dump(PROCESSORS["Contract Length"], open("./api/models/contract_length_encoder.pkl", "wb"))
pickle.dump(PROCESSORS["Gender"], open("./api/models/gender_encoder.pkl", "wb"))
The app.py script sets up a fully functional FastAPI service for customer churn prediction, encapsulating model loading, preprocessing, and inference.
import pickle
import uvicorn
import logging
import numpy as np
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from contextlib import asynccontextmanager

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
MODEL = None
GENDER_ENCODER = None
SUBSCRIPTION_ENCODER = None
CONTRACT_ENCODER = None
class CustomerData(BaseModel):
    age: int = Field(..., ge=18, le=100, description="Customer age (18-100)")
    gender: str = Field(..., description="Customer gender")
    tenure: int = Field(..., ge=0, description="Tenure in months")
    usage: float = Field(..., ge=0, description="Usage amount")
    support_calls: int = Field(..., ge=0, description="Number of support calls")
    payment_delay: int = Field(..., ge=0, description="Payment delay in days")
    subscription_type: str = Field(..., description="Type of subscription")
    contract_length: str = Field(..., description="Contract length")
    total_spend: float = Field(..., ge=0, description="Total amount spent")
    last_interaction: int = Field(..., ge=0, description="Days since last interaction")

    class Config:
        schema_extra = {
            "example": {
                "age": 35,
                "gender": "Male",
                "tenure": 24,
                "usage": 450.5,
                "support_calls": 3,
                "payment_delay": 2,
                "subscription_type": "Premium",
                "contract_length": "Annual",
                "total_spend": 1250.75,
                "last_interaction": 7
            }
        }

class PredictionResponse(BaseModel):
    churn_probability: float = Field(..., description="Probability of customer churn (0-1)")
    churn_prediction: str = Field(..., description="Churn prediction (Yes/No)")
    risk_level: str = Field(..., description="Risk level (Low/Medium/High)")
@asynccontextmanager
async def lifespan(app: FastAPI):
    global MODEL, GENDER_ENCODER, SUBSCRIPTION_ENCODER, CONTRACT_ENCODER
    try:
        with open('/app/models/churn_model.pkl', 'rb') as f:
            MODEL = pickle.load(f)
        logger.info("XGBoost model loaded successfully")

        with open('/app/models/gender_encoder.pkl', 'rb') as f:
            GENDER_ENCODER = pickle.load(f)
        logger.info("Gender encoder loaded successfully")

        with open('/app/models/subscription_type_encoder.pkl', 'rb') as f:
            SUBSCRIPTION_ENCODER = pickle.load(f)
        logger.info("Subscription encoder loaded successfully")

        with open('/app/models/contract_length_encoder.pkl', 'rb') as f:
            CONTRACT_ENCODER = pickle.load(f)
        logger.info("Contract encoder loaded successfully")
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise e

    yield
    logger.info("Shutting down application")
app = FastAPI(
    title="Customer Churn Prediction API",
    description="API for predicting customer churn using an XGBoost model",
    version="1.0.0",
    lifespan=lifespan
)
def preprocess_data(data: CustomerData) -> pd.DataFrame:
    """Preprocess input data by applying encoders and creating a DataFrame"""
    data_dict = data.model_dump()

    try:
        gender_encoded = GENDER_ENCODER.transform(np.array([data_dict['gender']]).reshape(-1, 1))
        if hasattr(gender_encoded, 'toarray'):  # OneHotEncoder case
            gender_encoded = gender_encoded.toarray()[0]
            data_dict['gender'] = gender_encoded[0] if len(gender_encoded) == 1 else gender_encoded
        else:
            data_dict['gender'] = int(gender_encoded[0])

        subscription_encoded = SUBSCRIPTION_ENCODER.transform(np.array([data_dict['subscription_type']]).reshape(-1, 1))
        if hasattr(subscription_encoded, 'toarray'):
            subscription_encoded = subscription_encoded.toarray()[0]
            data_dict['subscription_type'] = subscription_encoded[0] if len(subscription_encoded) == 1 else subscription_encoded
        else:
            data_dict['subscription_type'] = int(subscription_encoded[0])

        contract_encoded = CONTRACT_ENCODER.transform(np.array([data_dict['contract_length']]).reshape(-1, 1))
        if hasattr(contract_encoded, 'toarray'):
            contract_encoded = contract_encoded.toarray()[0]
            data_dict['contract_length'] = contract_encoded[0] if len(contract_encoded) == 1 else contract_encoded
        else:
            data_dict['contract_length'] = int(contract_encoded[0])
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid categorical value: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Encoding error: {str(e)}")

    feature_order = [
        'age', 'gender', 'tenure', 'usage',
        'support_calls', 'payment_delay', 'subscription_type',
        'contract_length', 'total_spend', 'last_interaction'
    ]

    processed_data = {
        'age': int(data_dict['age']),
        'gender': data_dict['gender'],
        'tenure': int(data_dict['tenure']),
        'usage': float(data_dict['usage']),
        'support_calls': int(data_dict['support_calls']),
        'payment_delay': int(data_dict['payment_delay']),
        'subscription_type': data_dict['subscription_type'],
        'contract_length': data_dict['contract_length'],
        'total_spend': float(data_dict['total_spend']),
        'last_interaction': int(data_dict['last_interaction'])
    }

    df = pd.DataFrame([processed_data], columns=feature_order)
    logger.info(f"DataFrame dtypes: {df.dtypes.to_dict()}")
    return df
def get_risk_level(probability: float) -> str:
    if probability < 0.3:
        return "Low"
    elif probability < 0.7:
        return "Medium"
    else:
        return "High"
@app.get("/")
async def root():
    return {"message": "Customer Churn Prediction API", "status": "healthy"}

@app.get("/health")
async def health_check() -> dict:
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "encoders_loaded": all([
            GENDER_ENCODER is not None,
            SUBSCRIPTION_ENCODER is not None,
            CONTRACT_ENCODER is not None
        ])
    }
@app.post("/predict", response_model=PredictionResponse)
async def predict_churn(customer_data: CustomerData):
    """Predict customer churn probability"""
    if MODEL is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        processed_data = preprocess_data(customer_data)
        logger.info(f"Processed data shape: {processed_data.shape}")
        logger.info(f"Processed data dtypes: {processed_data.dtypes.to_dict()}")

        try:
            churn_probability = MODEL.predict_proba(processed_data)[0][1]  # Probability of churn (class 1)
        except Exception as xgb_error:
            logger.warning(f"Standard prediction failed: {xgb_error}")
            churn_probability = MODEL.predict_proba(processed_data.values)[0][1]

        churn_prediction = "Yes" if churn_probability > 0.5 else "No"
        risk_level = get_risk_level(churn_probability)

        logger.info(f"Prediction made - Probability: {churn_probability:.4f}, Prediction: {churn_prediction}")

        return PredictionResponse(
            churn_probability=round(churn_probability, 4),
            churn_prediction=churn_prediction,
            risk_level=risk_level
        )
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8082)
1. Model and Encoder Initialization
@asynccontextmanager
async def lifespan(app: FastAPI):
    global MODEL, GENDER_ENCODER, SUBSCRIPTION_ENCODER, CONTRACT_ENCODER
    try:
        with open('/app/models/churn_model.pkl', 'rb') as f:
            MODEL = pickle.load(f)
        logger.info("XGBoost model loaded successfully")

        with open('/app/models/gender_encoder.pkl', 'rb') as f:
            GENDER_ENCODER = pickle.load(f)
        logger.info("Gender encoder loaded successfully")

        with open('/app/models/subscription_type_encoder.pkl', 'rb') as f:
            SUBSCRIPTION_ENCODER = pickle.load(f)
        logger.info("Subscription encoder loaded successfully")

        with open('/app/models/contract_length_encoder.pkl', 'rb') as f:
            CONTRACT_ENCODER = pickle.load(f)
        logger.info("Contract encoder loaded successfully")
    except Exception as e:
        logger.error(f"Error loading models: {str(e)}")
        raise e

    yield
    logger.info("Shutting down application")
Using the lifespan context manager, the app loads the serialized artifacts:
- XGBoost model (churn_model.pkl)
- Ordinal encoders for gender, subscription_type, and contract_length
These are made globally accessible across the app, ensuring they are loaded only once at startup.
2. Input Data Schema
class CustomerData(BaseModel):
    age: int = Field(..., ge=18, le=100, description="Customer age (18-100)")
    gender: str = Field(..., description="Customer gender")
    tenure: int = Field(..., ge=0, description="Tenure in months")
    usage: float = Field(..., ge=0, description="Usage amount")
    support_calls: int = Field(..., ge=0, description="Number of support calls")
    payment_delay: int = Field(..., ge=0, description="Payment delay in days")
    subscription_type: str = Field(..., description="Type of subscription")
    contract_length: str = Field(..., description="Contract length")
    total_spend: float = Field(..., ge=0, description="Total amount spent")
    last_interaction: int = Field(..., ge=0, description="Days since last interaction")

    class Config:
        schema_extra = {
            "example": {
                "age": 35,
                "gender": "Male",
                "tenure": 24,
                "usage": 450.5,
                "support_calls": 3,
                "payment_delay": 2,
                "subscription_type": "Premium",
                "contract_length": "Annual",
                "total_spend": 1250.75,
                "last_interaction": 7
            }
        }

class PredictionResponse(BaseModel):
    churn_probability: float = Field(..., description="Probability of customer churn (0-1)")
    churn_prediction: str = Field(..., description="Churn prediction (Yes/No)")
    risk_level: str = Field(..., description="Risk level (Low/Medium/High)")
CustomerData (via Pydantic) defines the expected input for predictions, complete with:
- Type constraints (e.g., age between 18 and 100)
- Descriptions for OpenAPI documentation
- An example payload
3. Preprocessing Pipeline
def preprocess_data(data: CustomerData) -> pd.DataFrame:
    """Preprocess input data by applying encoders and creating a DataFrame"""
    data_dict = data.model_dump()

    try:
        gender_encoded = GENDER_ENCODER.transform(np.array([data_dict['gender']]).reshape(-1, 1))
        if hasattr(gender_encoded, 'toarray'):  # OneHotEncoder case
            gender_encoded = gender_encoded.toarray()[0]
            data_dict['gender'] = gender_encoded[0] if len(gender_encoded) == 1 else gender_encoded
        else:
            data_dict['gender'] = int(gender_encoded[0])

        subscription_encoded = SUBSCRIPTION_ENCODER.transform(np.array([data_dict['subscription_type']]).reshape(-1, 1))
        if hasattr(subscription_encoded, 'toarray'):
            subscription_encoded = subscription_encoded.toarray()[0]
            data_dict['subscription_type'] = subscription_encoded[0] if len(subscription_encoded) == 1 else subscription_encoded
        else:
            data_dict['subscription_type'] = int(subscription_encoded[0])

        contract_encoded = CONTRACT_ENCODER.transform(np.array([data_dict['contract_length']]).reshape(-1, 1))
        if hasattr(contract_encoded, 'toarray'):
            contract_encoded = contract_encoded.toarray()[0]
            data_dict['contract_length'] = contract_encoded[0] if len(contract_encoded) == 1 else contract_encoded
        else:
            data_dict['contract_length'] = int(contract_encoded[0])
    except ValueError as e:
        raise HTTPException(status_code=400, detail=f"Invalid categorical value: {str(e)}")
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Encoding error: {str(e)}")

    feature_order = [
        'age', 'gender', 'tenure', 'usage',
        'support_calls', 'payment_delay', 'subscription_type',
        'contract_length', 'total_spend', 'last_interaction'
    ]

    processed_data = {
        'age': int(data_dict['age']),
        'gender': data_dict['gender'],
        'tenure': int(data_dict['tenure']),
        'usage': float(data_dict['usage']),
        'support_calls': int(data_dict['support_calls']),
        'payment_delay': int(data_dict['payment_delay']),
        'subscription_type': data_dict['subscription_type'],
        'contract_length': data_dict['contract_length'],
        'total_spend': float(data_dict['total_spend']),
        'last_interaction': int(data_dict['last_interaction'])
    }

    df = pd.DataFrame([processed_data], columns=feature_order)
    logger.info(f"DataFrame dtypes: {df.dtypes.to_dict()}")
    return df
The preprocess_data function:
- Converts user input into a dictionary
- Applies the loaded encoders to the categorical features
- Reconstructs a consistent pandas DataFrame in the required column order
This ensures that runtime inputs match the format used during model training, as the sketch below illustrates.
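As a rough local check (a sketch, assuming the global encoders have already been loaded as in lifespan), the function can be exercised with the example payload from the schema:
# Hypothetical local check; requires GENDER_ENCODER, SUBSCRIPTION_ENCODER and CONTRACT_ENCODER to be loaded
sample = CustomerData(
    age=35, gender="Male", tenure=24, usage=450.5,
    support_calls=3, payment_delay=2,
    subscription_type="Premium", contract_length="Annual",
    total_spend=1250.75, last_interaction=7
)
print(preprocess_data(sample))  # single-row DataFrame in the training feature order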
4. Prediction Endpoint
@app.post("/predict", response_model=PredictionResponse)
async def predict_churn(customer_data: CustomerData):
    """Predict customer churn probability"""
    if MODEL is None:
        raise HTTPException(status_code=500, detail="Model not loaded")

    try:
        processed_data = preprocess_data(customer_data)
        logger.info(f"Processed data shape: {processed_data.shape}")
        logger.info(f"Processed data dtypes: {processed_data.dtypes.to_dict()}")

        try:
            churn_probability = MODEL.predict_proba(processed_data)[0][1]  # Probability of churn (class 1)
        except Exception as xgb_error:
            logger.warning(f"Standard prediction failed: {xgb_error}")
            churn_probability = MODEL.predict_proba(processed_data.values)[0][1]

        churn_prediction = "Yes" if churn_probability > 0.5 else "No"
        risk_level = get_risk_level(churn_probability)

        logger.info(f"Prediction made - Probability: {churn_probability:.4f}, Prediction: {churn_prediction}")

        return PredictionResponse(
            churn_probability=round(churn_probability, 4),
            churn_prediction=churn_prediction,
            risk_level=risk_level
        )
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
The /predict POST endpoint:
- Accepts a JSON payload conforming to CustomerData
- Preprocesses the input
- Uses MODEL.predict_proba() to obtain the churn probability
- Classifies the output as "Yes"/"No" and assigns a risk level
The result is returned in a structured format via the PredictionResponse model.
5. Utility Endpoints
@app.get("/")
async def root():
    return {"message": "Customer Churn Prediction API", "status": "healthy"}

@app.get("/health")
async def health_check() -> dict:
    return {
        "status": "healthy",
        "model_loaded": MODEL is not None,
        "encoders_loaded": all([
            GENDER_ENCODER is not None,
            SUBSCRIPTION_ENCODER is not None,
            CONTRACT_ENCODER is not None
        ])
    }
- /: Root endpoint returning a welcome message
- /health: Diagnostic endpoint to check whether the model and encoders are loaded correctly
6. Deployment
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8082)
The app is served on 0.0.0.0:8082, making it suitable for Docker or cloud deployment. Logging is integrated throughout for transparency and debugging.
The Dockerfile sets up a clean, production-ready environment to containerize and deploy the FastAPI-based churn prediction API.
# Use official Python image as base
FROM python:3.10-slim

# Set workdir
WORKDIR /app

# System dependencies
RUN apt-get update && apt-get install -y gcc g++ curl && rm -rf /var/lib/apt/lists/*

# Install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Copy FastAPI app
COPY app.py .

# Copy models
COPY models/ /app/models/

# Expose FastAPI port
EXPOSE 8082

# Default is to run FastAPI; can be overridden with CMD in compose
CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8082"]
The test_api.py script provides a lightweight, functional test suite for verifying the operational readiness and correctness of your deployed churn prediction API.
import requests
import json

BASE_URL = "http://localhost:8082"

def test_health_check():
    """Test health check endpoint"""
    print("Testing health check...")
    response = requests.get(f"{BASE_URL}/health")
    print(f"Status Code: {response.status_code}")
    print(f"Response: {response.json()}")
    print("-" * 50)

def test_single_prediction():
    """Test single customer prediction"""
    print("Testing single prediction...")
    customer_data = {
        "age": 35,
        "gender": "Male",
        "tenure": 24,
        "usage": 450.5,
        "support_calls": 3,
        "payment_delay": 2,
        "subscription_type": "Premium",
        "contract_length": "Annual",
        "total_spend": 1250.75,
        "last_interaction": 7
    }
    response = requests.post(f"{BASE_URL}/predict", json=customer_data)
    print(f"Status Code: {response.status_code}")
    print(f"Response: {json.dumps(response.json(), indent=2)}")
    print("-" * 50)

if __name__ == "__main__":
    print("Starting API tests...")
    print("=" * 50)
    try:
        test_health_check()
        test_single_prediction()
        print("All tests completed!")
    except requests.exceptions.ConnectionError:
        print("Error: Could not connect to API. Make sure the service is running on localhost:8082")
    except Exception as e:
        print(f"Error during testing: {str(e)}")
Files are organized like this:
project/
├── app.py
├── models/
│   ├── churn_model.pkl
│   ├── gender_encoder.pkl
│   ├── subscription_type_encoder.pkl
│   └── contract_length_encoder.pkl
├── requirements.txt
├── Dockerfile
└── test_api.py
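The post does not show the contents of requirements.txt; based on the imports used by the API, a minimal version (an assumption; pin versions to match your training environment) could look like:
fastapi
uvicorn
pydantic
pandas
numpy
scikit-learn
xgboost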
Run the following from the root of your project directory (where the Dockerfile is located):
docker construct -t churn-api .
docker run -d -p 8082:8082 --name churn-api-container churn-api
curl http://localhost:8082/health
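Once the container is healthy, a prediction can be requested by POSTing the example payload to the /predict endpoint:
curl -X POST http://localhost:8082/predict \
  -H "Content-Type: application/json" \
  -d '{"age": 35, "gender": "Male", "tenure": 24, "usage": 450.5, "support_calls": 3, "payment_delay": 2, "subscription_type": "Premium", "contract_length": "Annual", "total_spend": 1250.75, "last_interaction": 7}'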