982 lines
61 KiB
Plaintext
982 lines
61 KiB
Plaintext
{
|
||
"cells": [
|
||
{
|
||
"cell_type": "code",
|
||
"id": "initial_id",
|
||
"metadata": {
|
||
"collapsed": true,
|
||
"ExecuteTime": {
|
||
"end_time": "2025-05-17T12:44:09.056878Z",
|
||
"start_time": "2025-05-17T12:44:06.294335Z"
|
||
}
|
||
},
|
||
"source": [
|
||
import matplotlib.pyplot as plt
import numpy as np  # used by the from-scratch models in later cells
import seaborn as sns
import pandas as pd

DATA_URL = "hf://datasets/schooly/online-shoppers-intention/online_shoppers_intention.csv"

# A remote hf:// read can fail with more than FileNotFoundError: e.g. ImportError when
# the hf:// filesystem backend is not installed, or an OSError subclass for network/HTTP
# problems (FileNotFoundError is itself an OSError) — TODO confirm for the installed
# fsspec/huggingface_hub versions. Report the cause and re-raise so "Run All" stops
# here with a traceback; exit() is meant for interactive shells, not notebook kernels.
try:
    df = pd.read_csv(DATA_URL)
except (ImportError, OSError) as exc:
    print(f"错误: 数据集联网加载失败 ({exc})")
    raise

# --- Initial data exploration ---
print("--- 数据集概览 ---")
print(df.head())
print("\n--- 数据信息 ---")
df.info()
print("\n--- 描述性统计 ---")
print(df.describe())
print("\n--- 缺失值检查 ---")
print(df.isnull().sum())

# Distribution of the target variable
print("\n--- 目标变量 'Revenue' 分布 ---")
print(df['Revenue'].value_counts(normalize=True))
sns.countplot(x='Revenue', data=df)
plt.title('Revenue Distribution')
plt.show()
|
||
],
|
||
"outputs": [
|
||
{
|
||
"name": "stderr",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"/home/grtsinry43/.conda/envs/ml/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||
]
|
||
},
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"--- 数据集概览 ---\n",
|
||
" Administrative Administrative_Duration Informational \\\n",
|
||
"0 0 0.0 0 \n",
|
||
"1 0 0.0 0 \n",
|
||
"2 0 0.0 0 \n",
|
||
"3 0 0.0 0 \n",
|
||
"4 0 0.0 0 \n",
|
||
"\n",
|
||
" Informational_Duration ProductRelated ProductRelated_Duration \\\n",
|
||
"0 0.0 1 0.000000 \n",
|
||
"1 0.0 2 64.000000 \n",
|
||
"2 0.0 1 0.000000 \n",
|
||
"3 0.0 2 2.666667 \n",
|
||
"4 0.0 10 627.500000 \n",
|
||
"\n",
|
||
" BounceRates ExitRates PageValues SpecialDay Month OperatingSystems \\\n",
|
||
"0 0.20 0.20 0.0 0.0 Feb 1 \n",
|
||
"1 0.00 0.10 0.0 0.0 Feb 2 \n",
|
||
"2 0.20 0.20 0.0 0.0 Feb 4 \n",
|
||
"3 0.05 0.14 0.0 0.0 Feb 3 \n",
|
||
"4 0.02 0.05 0.0 0.0 Feb 3 \n",
|
||
"\n",
|
||
" Browser Region TrafficType VisitorType Weekend Revenue \n",
|
||
"0 1 1 1 Returning_Visitor False False \n",
|
||
"1 2 1 2 Returning_Visitor False False \n",
|
||
"2 1 9 3 Returning_Visitor False False \n",
|
||
"3 2 2 4 Returning_Visitor False False \n",
|
||
"4 3 1 4 Returning_Visitor True False \n",
|
||
"\n",
|
||
"--- 数据信息 ---\n",
|
||
"<class 'pandas.core.frame.DataFrame'>\n",
|
||
"RangeIndex: 12330 entries, 0 to 12329\n",
|
||
"Data columns (total 18 columns):\n",
|
||
" # Column Non-Null Count Dtype \n",
|
||
"--- ------ -------------- ----- \n",
|
||
" 0 Administrative 12330 non-null int64 \n",
|
||
" 1 Administrative_Duration 12330 non-null float64\n",
|
||
" 2 Informational 12330 non-null int64 \n",
|
||
" 3 Informational_Duration 12330 non-null float64\n",
|
||
" 4 ProductRelated 12330 non-null int64 \n",
|
||
" 5 ProductRelated_Duration 12330 non-null float64\n",
|
||
" 6 BounceRates 12330 non-null float64\n",
|
||
" 7 ExitRates 12330 non-null float64\n",
|
||
" 8 PageValues 12330 non-null float64\n",
|
||
" 9 SpecialDay 12330 non-null float64\n",
|
||
" 10 Month 12330 non-null object \n",
|
||
" 11 OperatingSystems 12330 non-null int64 \n",
|
||
" 12 Browser 12330 non-null int64 \n",
|
||
" 13 Region 12330 non-null int64 \n",
|
||
" 14 TrafficType 12330 non-null int64 \n",
|
||
" 15 VisitorType 12330 non-null object \n",
|
||
" 16 Weekend 12330 non-null bool \n",
|
||
" 17 Revenue 12330 non-null bool \n",
|
||
"dtypes: bool(2), float64(7), int64(7), object(2)\n",
|
||
"memory usage: 1.5+ MB\n",
|
||
"\n",
|
||
"--- 描述性统计 ---\n",
|
||
" Administrative Administrative_Duration Informational \\\n",
|
||
"count 12330.000000 12330.000000 12330.000000 \n",
|
||
"mean 2.315166 80.818611 0.503569 \n",
|
||
"std 3.321784 176.779107 1.270156 \n",
|
||
"min 0.000000 0.000000 0.000000 \n",
|
||
"25% 0.000000 0.000000 0.000000 \n",
|
||
"50% 1.000000 7.500000 0.000000 \n",
|
||
"75% 4.000000 93.256250 0.000000 \n",
|
||
"max 27.000000 3398.750000 24.000000 \n",
|
||
"\n",
|
||
" Informational_Duration ProductRelated ProductRelated_Duration \\\n",
|
||
"count 12330.000000 12330.000000 12330.000000 \n",
|
||
"mean 34.472398 31.731468 1194.746220 \n",
|
||
"std 140.749294 44.475503 1913.669288 \n",
|
||
"min 0.000000 0.000000 0.000000 \n",
|
||
"25% 0.000000 7.000000 184.137500 \n",
|
||
"50% 0.000000 18.000000 598.936905 \n",
|
||
"75% 0.000000 38.000000 1464.157214 \n",
|
||
"max 2549.375000 705.000000 63973.522230 \n",
|
||
"\n",
|
||
" BounceRates ExitRates PageValues SpecialDay \\\n",
|
||
"count 12330.000000 12330.000000 12330.000000 12330.000000 \n",
|
||
"mean 0.022191 0.043073 5.889258 0.061427 \n",
|
||
"std 0.048488 0.048597 18.568437 0.198917 \n",
|
||
"min 0.000000 0.000000 0.000000 0.000000 \n",
|
||
"25% 0.000000 0.014286 0.000000 0.000000 \n",
|
||
"50% 0.003112 0.025156 0.000000 0.000000 \n",
|
||
"75% 0.016813 0.050000 0.000000 0.000000 \n",
|
||
"max 0.200000 0.200000 361.763742 1.000000 \n",
|
||
"\n",
|
||
" OperatingSystems Browser Region TrafficType \n",
|
||
"count 12330.000000 12330.000000 12330.000000 12330.000000 \n",
|
||
"mean 2.124006 2.357097 3.147364 4.069586 \n",
|
||
"std 0.911325 1.717277 2.401591 4.025169 \n",
|
||
"min 1.000000 1.000000 1.000000 1.000000 \n",
|
||
"25% 2.000000 2.000000 1.000000 2.000000 \n",
|
||
"50% 2.000000 2.000000 3.000000 2.000000 \n",
|
||
"75% 3.000000 2.000000 4.000000 4.000000 \n",
|
||
"max 8.000000 13.000000 9.000000 20.000000 \n",
|
||
"\n",
|
||
"--- 缺失值检查 ---\n",
|
||
"Administrative 0\n",
|
||
"Administrative_Duration 0\n",
|
||
"Informational 0\n",
|
||
"Informational_Duration 0\n",
|
||
"ProductRelated 0\n",
|
||
"ProductRelated_Duration 0\n",
|
||
"BounceRates 0\n",
|
||
"ExitRates 0\n",
|
||
"PageValues 0\n",
|
||
"SpecialDay 0\n",
|
||
"Month 0\n",
|
||
"OperatingSystems 0\n",
|
||
"Browser 0\n",
|
||
"Region 0\n",
|
||
"TrafficType 0\n",
|
||
"VisitorType 0\n",
|
||
"Weekend 0\n",
|
||
"Revenue 0\n",
|
||
"dtype: int64\n",
|
||
"\n",
|
||
"--- 目标变量 'Revenue' 分布 ---\n",
|
||
"Revenue\n",
|
||
"False 0.845255\n",
|
||
"True 0.154745\n",
|
||
"Name: proportion, dtype: float64\n"
|
||
]
|
||
},
|
||
{
|
||
"data": {
|
||
"text/plain": [
|
||
"<Figure size 640x480 with 1 Axes>"
|
||
],
|
||
"image/png": ""
|
||
},
|
||
"metadata": {},
|
||
"output_type": "display_data"
|
||
}
|
||
],
|
||
"execution_count": 1
|
||
},
|
||
{
|
||
"metadata": {
|
||
"ExecuteTime": {
|
||
"end_time": "2025-05-17T12:44:17.327021Z",
|
||
"start_time": "2025-05-17T12:44:17.235025Z"
|
||
}
|
||
},
|
||
"cell_type": "code",
|
||
"source": [
|
||
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.compose import ColumnTransformer

# --- Data preprocessing ---
# Convert boolean columns to integers (astype is idempotent, so re-running this cell is safe).
df['Weekend'] = df['Weekend'].astype(int)
df['Revenue'] = df['Revenue'].astype(int)  # target variable

# Categorical vs. numerical features.
categorical_features = ['Month', 'VisitorType', 'OperatingSystems', 'Browser', 'Region', 'TrafficType']
# 'OperatingSystems', 'Browser', 'Region', 'TrafficType' are stored as integers but encode
# categories, so cast them to strings and let OneHotEncoder treat them as categories.
for col in ['OperatingSystems', 'Browser', 'Region', 'TrafficType']:
    df[col] = df[col].astype(str)

numerical_features = ['Administrative', 'Administrative_Duration', 'Informational',
                      'Informational_Duration', 'ProductRelated', 'ProductRelated_Duration',
                      'BounceRates', 'ExitRates', 'PageValues', 'SpecialDay']

# Preprocessor: standardize numerical features, one-hot encode categorical features.
preprocessor = ColumnTransformer(
    transformers=[
        ('num', StandardScaler(), numerical_features),
        ('cat', OneHotEncoder(handle_unknown='ignore'), categorical_features)
    ],
    # Keep the columns not listed above unchanged. 'Revenue' is dropped from X below,
    # so the only passthrough column is 'Weekend'.
    remainder='passthrough'
)

# Separate features and target.
X = df.drop('Revenue', axis=1)
y = df['Revenue']

# Split BEFORE fitting the preprocessor, so the scaler's means/stds and the encoder's
# categories are learned from training rows only (fitting on all rows leaks test data).
# random_state makes the split reproducible; stratify keeps the class ratio in both parts.
X_train_raw, X_test_raw, y_train_series, y_test_series = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y)


def _to_dense(matrix):
    """Return a dense array, converting from a scipy sparse matrix when needed."""
    return matrix.toarray() if hasattr(matrix, "toarray") else matrix


# Note: ColumnTransformer reorders and expands the columns (num, then cat, then remainder).
X_train = _to_dense(preprocessor.fit_transform(X_train_raw))
X_test = _to_dense(preprocessor.transform(X_test_raw))
y_train = y_train_series.to_numpy()
y_test = y_test_series.to_numpy()

# Whole dataset transformed with the train-fitted preprocessor (kept for inspection).
X_processed = _to_dense(preprocessor.transform(X))

print("\n--- 处理后的特征维度 ---")
print(X_processed.shape)

print(f"训练集大小: X_train: {X_train.shape}, y_train: {y_train.shape}")
print(f"测试集大小: X_test: {X_test.shape}, y_test: {y_test.shape}")
|
||
],
|
||
"id": "1945b351cafe24fb",
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"--- 处理后的特征维度 ---\n",
|
||
"(12330, 74)\n",
|
||
"训练集大小: X_train: (9864, 74), y_train: (9864,)\n",
|
||
"测试集大小: X_test: (2466, 74), y_test: (2466,)\n"
|
||
]
|
||
}
|
||
],
|
||
"execution_count": 2
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
import numpy as np  # needed here: on a fresh kernel this cell runs before any later cell imports numpy


# --- Logistic regression implemented from scratch ---
class MyLogisticRegression:
    """Binary logistic regression fitted with full-batch gradient descent.

    Parameters
    ----------
    learning_rate : float
        Step size of each gradient-descent update.
    n_iterations : int
        Number of gradient-descent steps performed by ``fit``.
    verbose : bool
        If True, print the cost roughly every ``n_iterations // 10`` steps
        (every step when ``n_iterations < 10``) and on the final step.

    Attributes
    ----------
    weights : ndarray of shape (n_features,), or None before ``fit``.
    bias : float, or None before ``fit``.
    costs : list of float
        Binary cross-entropy recorded at every iteration of the last ``fit``.
    """

    def __init__(self, learning_rate=0.01, n_iterations=1000, verbose=False):
        self.learning_rate = learning_rate
        self.n_iterations = n_iterations
        self.weights = None
        self.bias = None
        self.verbose = verbose  # whether to print the cost during training
        self.costs = []  # cost recorded at each iteration

    def _sigmoid(self, z):
        """Element-wise logistic function; z is clipped to [-500, 500] so np.exp cannot overflow."""
        z = np.clip(z, -500, 500)
        return 1 / (1 + np.exp(-z))

    def fit(self, X, y):
        """Fit weights and bias by gradient descent on the mean binary cross-entropy.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like of shape (n_samples,) holding 0/1 labels

        Returns
        -------
        self
        """
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        n_samples, n_features = X.shape

        # Start from all-zero weights and bias.
        self.weights = np.zeros(n_features)
        self.bias = 0.0
        self.costs = []

        epsilon = 1e-9  # keeps log() finite when a probability reaches 0 or 1
        # max(1, ...) avoids a modulo-by-zero when n_iterations < 10.
        log_every = max(1, self.n_iterations // 10)

        for i in range(self.n_iterations):
            # Linear model z = X.w + b, mapped to probabilities by the sigmoid.
            y_predicted_proba = self._sigmoid(np.dot(X, self.weights) + self.bias)
            residual = y_predicted_proba - y

            # Gradient of the mean cross-entropy with respect to weights and bias.
            dw = (1 / n_samples) * np.dot(X.T, residual)
            db = (1 / n_samples) * np.sum(residual)

            self.weights -= self.learning_rate * dw
            self.bias -= self.learning_rate * db

            # Binary cross-entropy of the predictions made before this update.
            cost = -(1 / n_samples) * np.sum(
                y * np.log(y_predicted_proba + epsilon)
                + (1 - y) * np.log(1 - y_predicted_proba + epsilon))
            self.costs.append(cost)

            if self.verbose and (i % log_every == 0 or i == self.n_iterations - 1):
                print(f"Iteration {i}, Cost: {cost:.4f}")
        return self

    def predict_proba(self, X):
        """Return the predicted probability of class 1 for each row of X."""
        linear_model = np.dot(np.asarray(X, dtype=float), self.weights) + self.bias
        return self._sigmoid(linear_model)

    def predict(self, X, threshold=0.5):
        """Return 0/1 labels: 1 where the predicted probability is strictly greater than threshold."""
        return (self.predict_proba(X) > threshold).astype(int)
|
||
],
|
||
"id": "7b1a931cb634d14b"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
# --- Train the custom logistic regression model ---
print("\n--- 训练自定义逻辑回归模型 ---")
log_reg_model = MyLogisticRegression(learning_rate=0.1, n_iterations=2000, verbose=True)  # tuned hyperparameters
log_reg_model.fit(X_train, y_train)

# Training-cost curve: one point per gradient-descent iteration.
cost_fig, cost_ax = plt.subplots()
cost_ax.plot(range(len(log_reg_model.costs)), log_reg_model.costs)
cost_ax.set_xlabel("Iteration")
cost_ax.set_ylabel("Cost")
cost_ax.set_title("Logistic Regression Training Cost")
plt.show()

# --- Predictions on the test set ---
y_pred_labels = log_reg_model.predict(X_test)  # hard class labels
y_pred_proba = log_reg_model.predict_proba(X_test)  # probabilities, used for the ROC curve
|
||
],
|
||
"id": "ed11b643a3bf6061"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
from sklearn.metrics import (
    accuracy_score,
    auc,
    classification_report,
    confusion_matrix,
    f1_score,
    precision_score,
    recall_score,
    roc_curve,
)

# --- Evaluate the logistic-regression predictions on the test set ---
print("\n--- 模型评估 ---")
accuracy = accuracy_score(y_test, y_pred_labels)
print(f"Accuracy: {accuracy:.4f}")

print("\nClassification Report👀:")
print(classification_report(y_test, y_pred_labels, target_names=['Will Not Buy (0)', 'Will Buy (1)']))

# Confusion matrix: rows are actual classes, columns are predicted classes.
print("\n混淆矩阵:")
cm = confusion_matrix(y_test, y_pred_labels)
heatmap_labels = ['Will Not Buy', 'Will Buy']
cm_fig, cm_ax = plt.subplots()
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=heatmap_labels, yticklabels=heatmap_labels, ax=cm_ax)
cm_ax.set_xlabel('Predicted')
cm_ax.set_ylabel('Actual')
cm_ax.set_title('Confusion Matrix')
plt.show()

# ROC curve and AUC from the predicted probabilities.
fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
roc_auc = auc(fpr, tpr)

roc_fig, roc_ax = plt.subplots()
roc_ax.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc:.2f})')
roc_ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')  # chance-level diagonal
roc_ax.set_xlim([0.0, 1.0])
roc_ax.set_ylim([0.0, 1.05])
roc_ax.set_xlabel('False Positive Rate')
roc_ax.set_ylabel('True Positive Rate')
roc_ax.set_title('ROC Curve')
roc_ax.legend(loc="lower right")
plt.show()

print(f"AUC: {roc_auc:.4f}")
|
||
],
|
||
"id": "1b9c8a29f662d051"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from tqdm import tqdm\n",
|
||
"import time\n",
|
||
"\n",
|
||
"\n",
|
||
def linear_kernel(X1, X2):
    """Linear kernel: inner products between the rows of X1 and the rows of X2.

    For 2-D inputs the result has shape (len(X1), len(X2)).
    """
    X2_transposed = np.transpose(X2)
    return np.dot(X1, X2_transposed)
|
||
"\n",
|
||
"\n",
|
||
def rbf_kernel(X1, X2, gamma=0.1):
    """Gaussian (RBF) kernel matrix exp(-gamma * ||x1 - x2||^2) between rows of X1 and X2.

    1-D inputs are treated as a single row, so the result always has shape
    (rows of X1, rows of X2).
    """
    if X1.ndim == 1:
        X1 = X1[np.newaxis, :]
    if X2.ndim == 1:
        X2 = X2[np.newaxis, :]
    # Uses ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b. Floating-point cancellation can make
    # this slightly negative, which would produce kernel values above 1, so clamp at 0.
    sq_dists = np.sum(X1 ** 2, axis=1)[:, None] + np.sum(X2 ** 2, axis=1) - 2 * np.dot(X1, X2.T)
    sq_dists = np.maximum(sq_dists, 0.0)
    return np.exp(-gamma * sq_dists)
|
||
"\n",
|
||
"\n",
|
||
class SMO_SVM:
    """Kernel SVM trained with the simplified SMO algorithm.

    ``fit`` maps labels ``<= 0`` to -1 and all others to +1. ``predict`` maps a
    non-negative decision value to 1 and a negative one to 0.

    Parameters
    ----------
    C : float
        Upper bound on each Lagrange multiplier (soft-margin penalty).
    kernel : str
        ``'rbf'`` selects ``rbf_kernel``; any other value selects ``linear_kernel``.
    gamma : float
        RBF width parameter; unused by the linear kernel.
    tol : float
        Tolerance of the KKT-violation check.
    max_passes : int
        Training stops after this many consecutive passes in which no multiplier changed.
    random_state : None, int or numpy.random.Generator, optional
        Seed for choosing the second multiplier ``j``. ``None`` (default) draws from
        NumPy's global RNG, so ``np.random.seed`` still controls it.
    max_iter : int or None, optional
        Cap on the total number of passes over the data. ``None`` (default) means no
        cap, i.e. the original behaviour.
    """

    def __init__(self, C=1.0, kernel='rbf', gamma=0.1, tol=1e-3, max_passes=5,
                 random_state=None, max_iter=None):
        self.C = C
        self.gamma = gamma
        self.tol = tol
        self.max_passes = max_passes
        self.random_state = random_state
        self.max_iter = max_iter
        self.kernel = rbf_kernel if kernel == 'rbf' else linear_kernel
        self.alphas = None
        self.b = 0
        self.X = None
        self.y = None

    def _kernel_matrix(self, A, B):
        """Gram matrix between rows of A and rows of B, passing gamma to the RBF kernel."""
        if self.kernel is rbf_kernel:
            return rbf_kernel(A, B, self.gamma)
        return self.kernel(A, B)

    @staticmethod
    def _pick_other(i, n_samples, rng=None):
        """Return an index drawn uniformly from range(n_samples) excluding i, in O(1)."""
        # Draw among the n_samples - 1 other positions, then shift past i.
        if rng is None:
            j = int(np.random.randint(n_samples - 1))
        else:
            j = int(rng.integers(n_samples - 1))
        return j + 1 if j >= i else j

    def fit(self, X, y):
        """Train the multipliers ``alphas`` and bias ``b``.

        Parameters
        ----------
        X : ndarray of shape (n_samples, n_features)
        y : array of labels; values <= 0 are treated as the negative class.

        Returns
        -------
        self
        """
        y = np.where(y <= 0, -1, 1)
        n_samples, _ = X.shape
        self.X = X
        self.y = y
        self.alphas = np.zeros(n_samples)
        self.b = 0
        rng = None if self.random_state is None else np.random.default_rng(self.random_state)
        # The full n_samples x n_samples Gram matrix is computed once; memory grows as n_samples ** 2.
        K = self._kernel_matrix(X, X)

        passes = 0        # consecutive passes without any multiplier change
        total_passes = 0  # all passes, for the optional max_iter cap
        with tqdm(total=self.max_passes, desc="SVM Training Progress") as pbar:
            while passes < self.max_passes:
                if self.max_iter is not None and total_passes >= self.max_iter:
                    print(f"Stopping: reached max_iter={self.max_iter} passes")
                    break
                alpha_changed = 0
                print(f"\nPass {passes + 1}/{self.max_passes}")
                start_time = time.time()
                for i in range(n_samples):
                    if self._take_step(i, K, rng):
                        alpha_changed += 1

                print(
                    f"Pass {passes + 1} finished, alpha_changed: {alpha_changed}, time: {time.time() - start_time:.2f}s")
                total_passes += 1
                if alpha_changed == 0:
                    passes += 1
                    pbar.update(1)
                else:
                    passes = 0
                    # The bar counts consecutive quiet passes, so restart it from zero.
                    pbar.reset()
        return self

    def _take_step(self, i, K, rng):
        """Try one SMO update of the pair (i, j) with a randomly chosen j.

        Returns True if both multipliers and b were updated. Returns False if sample i
        satisfies the KKT conditions within tol or the pair cannot make progress.
        """
        y = self.y
        Ei = self._E(i, K)
        violates_kkt = ((y[i] * Ei < -self.tol and self.alphas[i] < self.C)
                        or (y[i] * Ei > self.tol and self.alphas[i] > 0))
        if not violates_kkt:
            return False

        j = self._pick_other(i, len(self.alphas), rng)
        Ej = self._E(j, K)

        alpha_i_old = self.alphas[i].copy()
        alpha_j_old = self.alphas[j].copy()

        # Bounds [L, H] for alpha_j that keep both multipliers inside [0, C].
        if y[i] != y[j]:
            L = max(0, self.alphas[j] - self.alphas[i])
            H = min(self.C, self.C + self.alphas[j] - self.alphas[i])
        else:
            L = max(0, self.alphas[i] + self.alphas[j] - self.C)
            H = min(self.C, self.alphas[i] + self.alphas[j])
        if L == H:
            return False

        eta = 2 * K[i, j] - K[i, i] - K[j, j]
        if eta >= 0:
            return False

        self.alphas[j] -= y[j] * (Ei - Ej) / eta
        self.alphas[j] = np.clip(self.alphas[j], L, H)
        # Negligible move: skip updating alpha_i and b (alpha_j keeps its clipped value).
        if abs(self.alphas[j] - alpha_j_old) < 1e-5:
            return False

        self.alphas[i] += y[i] * y[j] * (alpha_j_old - self.alphas[j])

        b1 = (self.b - Ei
              - y[i] * (self.alphas[i] - alpha_i_old) * K[i, i]
              - y[j] * (self.alphas[j] - alpha_j_old) * K[i, j])
        b2 = (self.b - Ej
              - y[i] * (self.alphas[i] - alpha_i_old) * K[i, j]
              - y[j] * (self.alphas[j] - alpha_j_old) * K[j, j])

        if 0 < self.alphas[i] < self.C:
            self.b = b1
        elif 0 < self.alphas[j] < self.C:
            self.b = b2
        else:
            self.b = (b1 + b2) / 2
        return True

    def _E(self, i, K):
        """Prediction error f(x_i) - y_i for training sample i."""
        return self._f(i, K) - self.y[i]

    def _f(self, i, K):
        """Decision value for training sample i using the precomputed Gram matrix K."""
        return np.sum(self.alphas * self.y * K[:, i]) + self.b

    def project(self, X):
        """Decision function sum_k alpha_k y_k K(x_k, x) + b for each row x of X."""
        K = self._kernel_matrix(self.X, X)
        return (self.alphas * self.y) @ K + self.b

    def predict(self, X):
        """Return 1 where the decision function is >= 0, else 0."""
        return np.where(self.project(X) >= 0, 1, 0)
|
||
],
|
||
"id": "6b16c3aa37baadb4"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
# --- Train the custom SVM (RBF kernel, as configured below) ---
# The header used to say "linear SVM" although kernel='rbf' is passed.
# SMO_SVM.fit precomputes the full n_train x n_train kernel matrix, so memory and
# time grow quadratically with the number of training rows.
print("\n--- 训练自定义RBF核SVM模型 ---")
svm_model = SMO_SVM(C=1.0, kernel='rbf', tol=1e-3, max_passes=5)
svm_model.fit(X_train, y_train)
y_pred_labels_svm = svm_model.predict(X_test)

print("\n--- SVM模型评估 ---")
accuracy_svm = accuracy_score(y_test, y_pred_labels_svm)
print(f"Accuracy: {accuracy_svm:.4f}")
print("\nClassification Report👀:")
print(classification_report(y_test, y_pred_labels_svm, target_names=['Will Not Buy (0)', 'Will Buy (1)']))

# Confusion matrix: rows are actual classes, columns are predicted classes.
print("\n混淆矩阵:")
cm_svm = confusion_matrix(y_test, y_pred_labels_svm)
sns.heatmap(cm_svm, annot=True, fmt='d', cmap='Blues', xticklabels=['Will Not Buy', 'Will Buy'],
            yticklabels=['Will Not Buy', 'Will Buy'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix (SVM)')
plt.show()

# ROC curve and AUC, scored with the continuous SVM decision function.
fpr_svm, tpr_svm, thresholds_svm = roc_curve(y_test, svm_model.project(X_test))
roc_auc_svm = auc(fpr_svm, tpr_svm)
plt.figure()
plt.plot(fpr_svm, tpr_svm, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_svm:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve (SVM)')
plt.legend(loc="lower right")
plt.show()

print(f"AUC: {roc_auc_svm:.4f}")
|
||
],
|
||
"id": "439fe97afc87ab54"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from collections import Counter\n",
|
||
"\n",
|
||
"\n",
|
||
class MyDecisionTreeClassifier:
    """Binary-split decision tree classifier built greedily on impurity decrease.

    Nodes are plain dicts. A leaf is ``{'value': label}``. An internal node holds
    ``feature_index``, ``threshold``, ``left``, ``right`` and ``info_gain``; samples
    with ``x[feature_index] <= threshold`` go to ``left``.

    Parameters
    ----------
    max_depth : int or None
        Maximum depth; None means unlimited.
    min_samples_split : int
        Nodes with fewer samples than this become leaves.
    criterion : str
        'gini' or 'entropy'.
    """

    def __init__(self, max_depth=None, min_samples_split=2, criterion='gini'):
        self.max_depth = max_depth
        self.min_samples_split = min_samples_split
        self.criterion = criterion  # 'gini' or 'entropy'
        self.tree = None

    def _calculate_impurity(self, y):
        """Gini impurity or entropy (in bits) of the labels y; 0 for an empty y.

        Raises ValueError for an unknown criterion (only when y is non-empty).
        """
        n_total = len(y)
        if n_total == 0:
            return 0
        proportions = [count / n_total for count in Counter(y).values()]

        if self.criterion == 'gini':
            # sum_k p_k * (1 - p_k), which equals 1 - sum_k p_k^2
            return sum(p * (1 - p) for p in proportions)
        if self.criterion == 'entropy':
            # skip p == 0 terms, where log2 is undefined
            return -sum(p * np.log2(p) for p in proportions if p > 0)
        raise ValueError("Unknown criterion.")

    def _calculate_information_gain(self, X_column, y, threshold):
        """Impurity decrease from splitting y by X_column <= threshold; 0 if a side is empty."""
        parent = self._calculate_impurity(y)

        y_left = y[X_column <= threshold]
        y_right = y[X_column > threshold]
        n_left, n_right = len(y_left), len(y_right)
        if n_left == 0 or n_right == 0:
            return 0

        n = len(y)
        weighted_children = ((n_left / n) * self._calculate_impurity(y_left)
                             + (n_right / n) * self._calculate_impurity(y_right))
        return parent - weighted_children

    def _find_best_split(self, X, y):
        """Return (feature_index, threshold, gain) of the best split; (None, None, -1) if none is tried."""
        best_split = (None, None, -1)

        for feature_idx in range(X.shape[1]):
            column = X[:, feature_idx]
            candidates = np.unique(column)
            # With many distinct values, only try the 10th..90th percentiles to save time.
            if len(candidates) > 10:
                candidates = np.percentile(column, np.arange(10, 100, 10))

            for threshold in candidates:
                gain = self._calculate_information_gain(column, y, threshold)
                if gain > best_split[2]:
                    best_split = (feature_idx, threshold, gain)

        return best_split

    @staticmethod
    def _make_leaf(y):
        """Leaf node predicting the most common label in y."""
        return {'value': Counter(y).most_common(1)[0][0]}

    def _build_tree(self, X, y, depth=0):
        """Recursively grow the subtree for samples (X, y) located at the given depth."""
        n_samples, _ = X.shape
        n_labels = len(np.unique(y))

        depth_limit_hit = self.max_depth is not None and depth >= self.max_depth
        if depth_limit_hit or n_labels == 1 or n_samples < self.min_samples_split:
            return self._make_leaf(y)

        feature_idx, threshold, gain = self._find_best_split(X, y)
        # Tiny gains are not worth a split (limits overfitting); tunable cut-off.
        if gain <= 0.001:
            return self._make_leaf(y)

        column = X[:, feature_idx]
        left_mask = column <= threshold
        right_mask = column > threshold
        # A split that leaves one side empty cannot continue: make this node a leaf.
        if not left_mask.any() or not right_mask.any():
            return self._make_leaf(y)

        return {
            'feature_index': feature_idx,
            'threshold': threshold,
            'left': self._build_tree(X[left_mask], y[left_mask], depth + 1),
            'right': self._build_tree(X[right_mask], y[right_mask], depth + 1),
            'info_gain': gain,  # optional, kept for inspection
        }

    def fit(self, X, y):
        """Grow the tree from training features X (n_samples, n_features) and labels y."""
        self.tree = self._build_tree(X, y)

    def _traverse_tree(self, x, node):
        """Follow the split rules from node down to a leaf and return its label."""
        while 'value' not in node:
            branch = 'left' if x[node['feature_index']] <= node['threshold'] else 'right'
            node = node[branch]
        return node['value']

    def predict(self, X):
        """Predict a label for every row of X."""
        return np.array(list(map(lambda row: self._traverse_tree(row, self.tree), X)))
|
||
],
|
||
"id": "81881edd3d6dac8c"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
# --- Train the custom decision tree ---
print("\n--- 训练自定义决策树模型 ---")
tree_model = MyDecisionTreeClassifier(max_depth=5, min_samples_split=2, criterion='gini')
tree_model.fit(X_train, y_train)
y_pred_labels_tree = tree_model.predict(X_test)
print("\n--- 决策树模型评估 ---")
accuracy_tree = accuracy_score(y_test, y_pred_labels_tree)
print(f"Accuracy: {accuracy_tree:.4f}")
print("\nClassification Report👀:")
print(classification_report(y_test, y_pred_labels_tree, target_names=['Will Not Buy (0)', 'Will Buy (1)']))

# Confusion matrix: rows are actual classes, columns are predicted classes.
print("\n混淆矩阵:")
cm_tree = confusion_matrix(y_test, y_pred_labels_tree)
sns.heatmap(cm_tree, annot=True, fmt='d', cmap='Blues', xticklabels=['Will Not Buy', 'Will Buy'],
            yticklabels=['Will Not Buy', 'Will Buy'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix (Decision Tree)')
plt.show()


def _tree_leaf(node, x):
    """Return the leaf dict of a MyDecisionTreeClassifier tree that sample x is routed to."""
    # Same routing rule as MyDecisionTreeClassifier._traverse_tree: <= threshold goes left.
    while 'value' not in node:
        node = node['left'] if x[node['feature_index']] <= node['threshold'] else node['right']
    return node


# ROC curve and AUC.
# roc_curve needs a continuous score; the hard 0/1 predictions give only one operating
# point. Score each test sample by the share of positive training samples in its leaf.
leaf_label_counts = {}  # id(leaf dict) -> [n_positive, n_total] over the training set
for x_row, label in zip(X_train, y_train):
    counts = leaf_label_counts.setdefault(id(_tree_leaf(tree_model.tree, x_row)), [0, 0])
    counts[0] += int(label == 1)
    counts[1] += 1

tree_scores = []
for x_row in X_test:
    leaf = _tree_leaf(tree_model.tree, x_row)
    # Fallback to the leaf's label; not expected to trigger, since every leaf is grown
    # from a non-empty subset of the training data.
    n_pos, n_total = leaf_label_counts.get(id(leaf), (int(leaf['value'] == 1), 1))
    tree_scores.append(n_pos / n_total)
y_score_tree = np.array(tree_scores)

fpr_tree, tpr_tree, thresholds_tree = roc_curve(y_test, y_score_tree)
roc_auc_tree = auc(fpr_tree, tpr_tree)
plt.figure()
plt.plot(fpr_tree, tpr_tree, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_tree:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve (Decision Tree)')
plt.legend(loc="lower right")
plt.show()

print(f"AUC: {roc_auc_tree:.4f}")
|
||
],
|
||
"id": "f82f5a58da4117e1"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"\n",
|
||
"\n",
|
||
class MyGaussianNaiveBayes:
    """Gaussian naive Bayes classifier.

    Each feature is modelled per class as an independent normal distribution with
    the class mean and variance (plus ``epsilon``) estimated in ``fit``.
    """

    def __init__(self):
        self.class_priors_ = None
        self.class_means_ = None
        self.class_vars_ = None  # per-class feature variances (with epsilon added)
        self.classes_ = None
        self.epsilon = 1e-9  # guards against zero variance and log(0)

    def fit(self, X, y):
        """Estimate class priors and per-class feature means/variances from X and y."""
        n_samples, n_features = X.shape
        self.classes_ = np.unique(y)
        n_classes = len(self.classes_)

        priors = np.zeros(n_classes)
        means = np.zeros((n_classes, n_features))
        variances = np.zeros((n_classes, n_features))
        for row, label in enumerate(self.classes_):
            members = X[y == label]  # all samples of this class
            priors[row] = members.shape[0] / n_samples
            means[row, :] = members.mean(axis=0)
            variances[row, :] = members.var(axis=0) + self.epsilon  # epsilon keeps variance > 0

        self.class_priors_ = priors
        self.class_means_ = means
        self.class_vars_ = variances

    def _pdf(self, class_idx, x_row):
        """Per-feature normal density P(x_j | class) for one sample (not in log space)."""
        mean = self.class_means_[class_idx]
        var = self.class_vars_[class_idx]
        # Direct (non-log) densities can underflow; the log version below is what the
        # prediction methods use.
        numerator = np.exp(-((x_row - mean) ** 2) / (2 * var))
        denominator = np.sqrt(2 * np.pi * var)
        return numerator / denominator

    def _calculate_log_class_probability(self, class_idx, x_row):
        """Unnormalized log posterior: log prior + sum of per-feature normal log densities."""
        log_prior = np.log(self.class_priors_[class_idx] + self.epsilon)
        mean = self.class_means_[class_idx]
        var = self.class_vars_[class_idx]  # var = std ** 2
        # log N(x_j; mean_j, var_j) = -0.5 * log(2*pi*var_j) - (x_j - mean_j)^2 / (2*var_j)
        log_density = -0.5 * np.log(2 * np.pi * var) - 0.5 * ((x_row - mean) ** 2) / var
        return log_prior + np.sum(log_density)

    def predict_proba(self, X):
        """Normalized posterior probabilities, shape (n_samples, n_classes), columns ordered as classes_."""
        n_samples = X.shape[0]
        n_classes = len(self.classes_)
        log_post = np.zeros((n_samples, n_classes))
        for i in range(n_samples):
            sample = X[i]
            for k in range(n_classes):
                log_post[i, k] = self._calculate_log_class_probability(k, sample)

        # Softmax with the row maximum subtracted first (log-sum-exp trick) to avoid under/overflow.
        shifted = np.exp(log_post - np.max(log_post, axis=1, keepdims=True))
        return shifted / np.sum(shifted, axis=1, keepdims=True)

    def predict(self, X):
        """Return, for each row of X, the class with the highest log posterior."""
        return np.array([
            self.classes_[np.argmax([self._calculate_log_class_probability(k, x_row)
                                     for k, _ in enumerate(self.classes_)])]
            for x_row in X
        ])
|
||
],
|
||
"id": "b12fa40a4de770e5"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
# --- Train the custom Gaussian naive Bayes model ---
print("\n--- 训练自定义朴素贝叶斯模型 ---")
nb_model = MyGaussianNaiveBayes()
nb_model.fit(X_train, y_train)
y_pred_labels_nb = nb_model.predict(X_test)
print("\n--- 朴素贝叶斯模型评估 ---")
accuracy_nb = accuracy_score(y_test, y_pred_labels_nb)
print(f"Accuracy: {accuracy_nb:.4f}")
print("\nClassification Report👀:")
print(classification_report(y_test, y_pred_labels_nb, target_names=['Will Not Buy (0)', 'Will Buy (1)']))

# Confusion matrix: rows are actual classes, columns are predicted classes.
print("\n混淆矩阵:")
cm_nb = confusion_matrix(y_test, y_pred_labels_nb)
sns.heatmap(cm_nb, annot=True, fmt='d', cmap='Blues', xticklabels=['Will Not Buy', 'Will Buy'],
            yticklabels=['Will Not Buy', 'Will Buy'])
plt.xlabel('Predicted')
plt.ylabel('Actual')
plt.title('Confusion Matrix (Naive Bayes)')
plt.show()

# ROC curve and AUC: score with the predicted probability of the positive class (label 1)
# instead of hard labels, so roc_curve can sweep over thresholds.
positive_column = list(nb_model.classes_).index(1)  # predict_proba columns follow classes_
y_score_nb = nb_model.predict_proba(X_test)[:, positive_column]
fpr_nb, tpr_nb, thresholds_nb = roc_curve(y_test, y_score_nb)
roc_auc_nb = auc(fpr_nb, tpr_nb)
plt.figure()
plt.plot(fpr_nb, tpr_nb, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_nb:.2f})')
plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve (Naive Bayes)')
plt.legend(loc="lower right")
plt.show()

print(f"AUC: {roc_auc_nb:.4f}")
|
||
],
|
||
"id": "e8c167999fb13f6b"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
"import numpy as np\n",
|
||
"from collections import Counter\n",
|
||
"\n",
|
||
"\n",
|
||
"class MyKNearestNeighbors:\n",
"    \"\"\"Brute-force k-nearest-neighbours classifier (Euclidean distance, majority vote).\n",
"\n",
"    Parameters\n",
"    ----------\n",
"    k : int, default 3\n",
"        Number of neighbours that vote on each prediction; must be >= 1.\n",
"    \"\"\"\n",
"\n",
"    def __init__(self, k=3):\n",
"        if k < 1:\n",
"            raise ValueError(f\"k must be >= 1, got {k}\")\n",
"        self.k = k\n",
"        self.X_train = None\n",
"        self.y_train = None\n",
"        self.classes_ = None\n",
"\n",
"    def fit(self, X_train, y_train):\n",
"        \"\"\"Store the training data (KNN is lazy: all work happens at predict time).\n",
"\n",
"        Both inputs are converted to ndarrays so neighbour lookup is positional.\n",
"        Indexing a pandas Series with ``y[i]`` uses its index *labels*, which are\n",
"        shuffled after ``train_test_split``, and would return the wrong samples.\n",
"        Returns ``self``.\n",
"        \"\"\"\n",
"        self.X_train = np.asarray(X_train, dtype=float)\n",
"        self.y_train = np.asarray(y_train)\n",
"        self.classes_ = np.unique(self.y_train)\n",
"        return self\n",
"\n",
"    def _check_fitted(self):\n",
"        \"\"\"Raise RuntimeError if ``fit`` has not been called yet.\"\"\"\n",
"        if self.X_train is None:\n",
"            raise RuntimeError(\"MyKNearestNeighbors.fit must be called before predicting\")\n",
"\n",
"    def _euclidean_distance(self, x1, x2):\n",
"        \"\"\"Euclidean distance along the last axis; ``x2`` may be a 2-D batch of points.\"\"\"\n",
"        return np.sqrt(np.sum((x1 - x2) ** 2, axis=-1))\n",
"\n",
"    def _nearest_labels(self, x_test_sample):\n",
"        \"\"\"Labels of the ``k`` training samples closest to one query point, nearest first.\"\"\"\n",
"        # One vectorised distance computation against the whole training set\n",
"        # instead of a Python-level loop over every training row.\n",
"        distances = self._euclidean_distance(np.asarray(x_test_sample, dtype=float), self.X_train)\n",
"        k_indices = np.argsort(distances, kind=\"stable\")[:self.k]\n",
"        return self.y_train[k_indices]\n",
"\n",
"    def _predict_single(self, x_test_sample):\n",
"        \"\"\"Majority label among the k nearest neighbours of one sample.\"\"\"\n",
"        # On a tie, Counter.most_common keeps first-encountered order, so the tied\n",
"        # label whose first occurrence is nearest wins.\n",
"        return Counter(self._nearest_labels(x_test_sample).tolist()).most_common(1)[0][0]\n",
"\n",
"    def predict(self, X_test):\n",
"        \"\"\"Predict one label per row of ``X_test``; returns an ndarray.\"\"\"\n",
"        self._check_fitted()\n",
"        return np.array([self._predict_single(x) for x in np.asarray(X_test, dtype=float)])\n",
"\n",
"    def predict_proba(self, X_test):\n",
"        \"\"\"Per-class fraction of the k nearest neighbours, shape (n_samples, n_classes).\n",
"\n",
"        Column j corresponds to ``self.classes_[j]``. With k neighbours the scores take\n",
"        at most k + 1 distinct values, but unlike hard labels they can be thresholded\n",
"        (e.g. for ``roc_curve``).\n",
"        \"\"\"\n",
"        self._check_fitted()\n",
"        X_arr = np.asarray(X_test, dtype=float)\n",
"        proba = np.empty((len(X_arr), len(self.classes_)))\n",
"        for row, x in enumerate(X_arr):\n",
"            labels = self._nearest_labels(x)\n",
"            proba[row] = [np.mean(labels == c) for c in self.classes_]\n",
"        return proba"
|
||
],
|
||
"id": "718bb29ac00a859c"
|
||
},
|
||
{
|
||
"metadata": {},
|
||
"cell_type": "code",
|
||
"outputs": [],
|
||
"execution_count": null,
|
||
"source": [
|
||
"# --- 训练自定义KNN模型 ---\n",
"print(\"\\n--- 训练自定义KNN模型 ---\")\n",
"\n",
"\n",
"def knn_vote_fraction(model, X, positive_label=1):\n",
"    \"\"\"Fraction of each row's ``model.k`` nearest training samples labelled ``positive_label``.\n",
"\n",
"    ROC analysis needs a continuous score, while ``model.predict`` returns hard labels.\n",
"    This reruns the same Euclidean neighbour search on the fitted model's ``X_train`` /\n",
"    ``y_train`` (positionally) and returns an ndarray of shape (n_samples,).\n",
"    \"\"\"\n",
"    X_ref = np.asarray(model.X_train, dtype=float)\n",
"    y_ref = np.asarray(model.y_train)\n",
"    X_arr = np.asarray(X, dtype=float)\n",
"    scores = np.empty(len(X_arr))\n",
"    for row, x in enumerate(X_arr):\n",
"        distances = np.sqrt(((X_ref - x) ** 2).sum(axis=1))\n",
"        nearest = np.argsort(distances)[:model.k]\n",
"        scores[row] = np.mean(y_ref[nearest] == positive_label)\n",
"    return scores\n",
"\n",
"\n",
"knn_model = MyKNearestNeighbors(k=5)  # k is tunable\n",
"knn_model.fit(X_train, y_train)\n",
"# The custom model iterates rows positionally, so hand it a plain ndarray.\n",
"X_test_arr = np.asarray(X_test)\n",
"y_pred_labels_knn = knn_model.predict(X_test_arr)\n",
"# Continuous score for the positive class (label 1, \"Will Buy\"), used for ROC/AUC.\n",
"y_score_knn = knn_vote_fraction(knn_model, X_test_arr, positive_label=1)\n",
"print(\"\\n--- KNN模型评估 ---\")\n",
"accuracy_knn = accuracy_score(y_test, y_pred_labels_knn)\n",
"print(f\"Accuracy: {accuracy_knn:.4f}\")\n",
"print(\"\\nClassification Report👀:\")\n",
"print(classification_report(y_test, y_pred_labels_knn, target_names=['Will Not Buy (0)', 'Will Buy (1)']))\n",
"print(\"\\n混淆矩阵:\")\n",
"cm_knn = confusion_matrix(y_test, y_pred_labels_knn)\n",
"sns.heatmap(cm_knn, annot=True, fmt='d', cmap='Blues', xticklabels=['Will Not Buy', 'Will Buy'],\n",
"            yticklabels=['Will Not Buy', 'Will Buy'])\n",
"plt.xlabel('Predicted')\n",
"plt.ylabel('Actual')\n",
"plt.title('Confusion Matrix (KNN)')\n",
"plt.show()\n",
"# ROC curve and AUC from the neighbour-vote score. Hard 0/1 predictions would give\n",
"# a single operating point, a degenerate curve, and an underestimated AUC.\n",
"fpr_knn, tpr_knn, thresholds_knn = roc_curve(y_test, y_score_knn)\n",
"roc_auc_knn = auc(fpr_knn, tpr_knn)\n",
"plt.figure()\n",
"plt.plot(fpr_knn, tpr_knn, color='darkorange', lw=2, label=f'ROC curve (area = {roc_auc_knn:.2f})')\n",
"plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')\n",
"plt.xlim([0.0, 1.0])\n",
"plt.ylim([0.0, 1.05])\n",
"plt.xlabel('False Positive Rate')\n",
"plt.ylabel('True Positive Rate')\n",
"plt.title('ROC Curve (KNN)')\n",
"plt.legend(loc=\"lower right\")\n",
"plt.show()\n",
"\n",
"print(f\"AUC: {roc_auc_knn:.4f}\")\n"
|
||
],
|
||
"id": "fc3b39b5d46a42a"
|
||
}
|
||
],
|
||
"metadata": {
|
||
"kernelspec": {
|
||
"display_name": "Python 3",
|
||
"language": "python",
|
||
"name": "python3"
|
||
},
|
||
"language_info": {
|
||
"codemirror_mode": {
|
||
"name": "ipython",
|
||
"version": 2
|
||
},
|
||
"file_extension": ".py",
|
||
"mimetype": "text/x-python",
|
||
"name": "python",
|
||
"nbconvert_exporter": "python",
|
||
"pygments_lexer": "ipython2",
|
||
"version": "2.7.6"
|
||
}
|
||
},
|
||
"nbformat": 4,
|
||
"nbformat_minor": 5
|
||
}
|