Skip to content

Commit f44646b

Browse files
authored
1 parent 0ccc63f commit f44646b

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

machine_learning/xgboostclassifier.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,29 +12,11 @@ def data_handling(data: dict) -> tuple:
1212
return x
1313

1414

15-
def xgboost(
16-
features: list,
17-
target: list,
18-
test_features: list,
19-
test_targets: list,
20-
namesofflowers: list,
21-
) -> None:
15+
def xgboost(features: list,target: list) # -> returns a trained model:
2216
classifier = XGBClassifier()
2317
classifier.fit(features, target)
24-
# Display Confusion Matrix of Classifier
25-
# with both train and test sets
26-
plot_confusion_matrix(
27-
classifier,
28-
test_features,
29-
test_targets,
30-
display_labels=namesofflowers,
31-
cmap="Blues",
32-
normalize="true",
33-
)
34-
plt.title("Normalized Confusion Matrix - IRIS Dataset")
35-
plt.show()
36-
37-
18+
return classifier
19+
3820
def main() -> None:
3921

4022
"""
@@ -53,7 +35,20 @@ def main() -> None:
5335
)
5436

5537
# XGBoost Classifier
56-
xgboost(x_train, y_train, x_test, y_test, names)
38+
xgb=xgboost(x_train, y_train)
39+
40+
# Display Confusion Matrix of Classifier
41+
# with both train and test sets
42+
plot_confusion_matrix(
43+
classifier,
44+
x_test,
45+
y_test,
46+
display_labels=names,
47+
cmap="Blues",
48+
normalize="true",
49+
)
50+
plt.title("Normalized Confusion Matrix - IRIS Dataset")
51+
plt.show()
5752

5853

5954
if __name__ == "__main__":
@@ -62,4 +57,4 @@ def main() -> None:
6257
doctest.testmod(name="main", verbose=True)
6358
doctest.testmod(name="xgboost", verbose=True)
6459
doctest.testmod(name="data_handling", verbose=True)
65-
# main()
60+
#main()

0 commit comments

Comments
 (0)