Rename, create new dataset
This commit is contained in:
parent
67403ebbdf
commit
6925b7b741
1 changed files with 27 additions and 6 deletions
|
@ -7,15 +7,26 @@ import numpy as np
|
|||
|
||||
from perceptron import Perceptron
|
||||
|
||||
|
||||
def main():
|
||||
nb_data_points = 600
|
||||
nb_training_points = int(nb_data_points * 0.8)
|
||||
|
||||
# Génération des données d'entraînement et de test
|
||||
data = np.random.rand(nb_data_points, 2)
|
||||
expected = np.array([1 if x[0] + x[1] > 0.7 else 0 for x in data])
|
||||
data1 = np.random.rand(nb_data_points // 2, 2) + 3
|
||||
expected1 = np.array([1 for _ in data1])
|
||||
data0 = np.random.rand(nb_data_points // 2, 2) + 1
|
||||
expected0 = np.array([0 for _ in data0])
|
||||
data = np.concatenate((data1, data0))
|
||||
expected = np.concatenate((expected1, expected0))
|
||||
permutation = np.random.permutation(nb_data_points)
|
||||
data = data[permutation]
|
||||
expected = expected[permutation]
|
||||
|
||||
# Affichage des points
|
||||
plot_dots(data, expected)
|
||||
|
||||
training_data = data[:nb_training_points]
|
||||
training_expected = expected[:nb_training_points]
|
||||
testing_data = data[nb_training_points:]
|
||||
testing_expected = expected[nb_training_points:]
|
||||
|
||||
|
@ -25,9 +36,7 @@ def main():
|
|||
|
||||
# Classement sur les données de test
|
||||
predicted = [p.predict(x) for x in testing_data]
|
||||
|
||||
# Affichage des points
|
||||
plot_dots(data, expected)
|
||||
print(error_rate(predicted, testing_expected))
|
||||
|
||||
# Affichage de la ligne de séparation
|
||||
x = np.linspace(0, 1, 50)
|
||||
|
@ -37,7 +46,19 @@ def main():
|
|||
plt.show()
|
||||
|
||||
|
||||
def error_rate(predicted, expected):
|
||||
"""Calculate the error rate of a prediction set"""
|
||||
|
||||
assert len(predicted) == len(expected)
|
||||
errors = 0
|
||||
for i, item in enumerate(predicted):
|
||||
if item != expected[i]:
|
||||
errors += 1
|
||||
return errors / len(predicted)
|
||||
|
||||
|
||||
def plot_dots(data, expected):
|
||||
"""Plot all dots in the data"""
|
||||
# abscisses de la classe 1
|
||||
xpoints_1 = [x[0] for i, x in enumerate(data) if expected[i] == 1]
|
||||
# ordonnées de la classe 1
|
Loading…
Reference in a new issue