271 lines
8.5 KiB
Python
271 lines
8.5 KiB
Python
import random
|
|
from pickletools import read_uint1
|
|
|
|
|
|
class Field:
    """An 8-queens board with one queen per column.

    ``state[c]`` holds the row (1..8) of the queen in column ``c`` (0-based,
    column 0 is printed as "A"). The state may hold fewer than 8 entries while
    the backtracking search fills it column by column.

    Attributes:
        state: queen rows, one entry per filled column.
        domain_values: candidate rows for the next column; Backtrack prunes
            this list in place.
        threats: number of queen pairs that attack each other (an int).
        fitness: ``28 - threats``. 28 is the number of queen pairs on a full
            board (8 choose 2), so a fitness of 28 means a solved board.
        model: solver run by ``calc`` ("genetic" or "backtrack"; any other
            value makes ``calc`` a no-op).
    """

    def __init__(self, init_state=None, model=None):
        """Create a board.

        Args:
            init_state: optional sequence of rows (1..8), one per column. It is
                copied, so later changes to the caller's list do not leak in.
                When omitted, 8 queens are placed in random rows.
            model: name of the solver used by ``calc``.
        """
        self.domain_values = [1, 2, 3, 4, 5, 6, 7, 8]
        if init_state is None:
            self.state = [random.randint(1, 8) for _ in range(8)]  # row number [1:8]
        else:
            self.state = list(init_state)
        self._update_scores()
        self.model = model

    def _update_scores(self):
        """Recompute ``threats`` and ``fitness`` from the current state."""
        self.threats = self.collisions()
        self.fitness = 28 - self.threats

    def get_fitness(self):
        """Return ``28 - threats`` (28 means no two queens attack each other)."""
        return self.fitness

    def get_state(self):
        """Return the live list of queen rows (not a copy)."""
        return self.state

    def get_domain_values(self):
        """Return the live list of candidate rows for the next column."""
        return self.domain_values

    # Actions

    def set_state(self, column, row=None):
        """Place the queen of ``column`` without rescoring the board.

        With ``row`` None a random row 1..8 is chosen (``column`` must already
        exist, otherwise IndexError). Otherwise rows outside 1..8 are ignored,
        an existing column is overwritten, ``column == len(state)`` appends a
        new column, and any other column is ignored.
        """
        if row is None:
            self.state[column] = random.randint(1, 8)
        elif 0 < row < 9:
            if column < len(self.state):
                self.state[column] = row
            elif column == len(self.state):
                self.state.append(row)

    def set_domain_values(self, new_domain):
        """Replace the candidate rows for the next column."""
        self.domain_values = new_domain

    def add_queen(self, row):
        """Put a queen in ``row`` of the next free column and rescore.

        Does nothing when all 8 columns are already filled.
        """
        if len(self.state) < 8:
            self.set_state(len(self.state), row)
            # Keep threats/fitness consistent with the grown state.
            self._update_scores()

    def move_queen(self, column, new_row=None):
        """Move the queen of ``column`` to ``new_row`` (random if None) and rescore."""
        self.set_state(column, new_row)
        self._update_scores()

    def move_all_queens(self, new_state=None):
        """Move every queen.

        Without ``new_state`` each of the 8 queens gets a random row; otherwise
        column ``i`` moves to ``new_state[i]``.
        """
        if new_state is None:
            for column in range(8):
                self.move_queen(column)
        else:
            for column, new_row in enumerate(new_state):
                self.move_queen(column, new_row)

    # heuristics functions

    def collisions(self, current_state=None):
        """Count the pairs of queens that attack each other.

        Two queens in columns i < j attack each other when they share a row
        or lie on a common diagonal. On a diagonal the row distance equals
        the column distance (an isosceles right triangle). Column clashes are
        impossible because each column holds exactly one queen.

        Args:
            current_state: rows to evaluate; defaults to this board's state.

        Returns:
            The number of attacking pairs, each pair counted once, as an int.
        """
        if current_state is None:
            current_state = self.get_state()

        pairs = 0
        for i, row_i in enumerate(current_state):
            # Only look to the right of column i so every pair is seen once.
            for distance, row_j in enumerate(current_state[i + 1:], start=1):
                if row_i == row_j or abs(row_i - row_j) == distance:
                    pairs += 1
        return pairs

    def print_field(self):
        """Print the board as a grid with row 8 at the top and columns A..H.

        Queens are drawn as "Q". Squares where (row + column) is even are
        shaded. Columns without a queen yet (partial states) are drawn empty.
        """
        cell = "───"
        print("\n  ┌" + "┬".join([cell] * 8) + "┐")
        for row in range(8, 0, -1):  # rows 8 down to 1
            row_string = ""
            for column in range(8):
                has_queen = column < len(self.state) and self.state[column] == row
                dark = (row + column) % 2 == 0
                if has_queen:
                    row_string += "▌Q▐│" if dark else " Q │"
                else:
                    row_string += "███│" if dark else "   │"
            print(f"{row} │{row_string}")
            if row > 1:
                print("  ├" + "┼".join([cell] * 8) + "┤")
        print("  └" + "┴".join([cell] * 8) + "┘")
        # Each label sits under the middle character of its 3-wide cell.
        print("    " + "   ".join("ABCDEFGH") + "\n")

    def calc(self):
        """Run the solver named by ``self.model``.

        "genetic": runs Genetic, adopts the best board it returns (state,
        threats and fitness) and prints it. "backtrack": runs Backtrack, which
        prints the solutions it finds. Any other model does nothing.
        """
        if self.model == "genetic":
            best_field = Genetic().calc()
            self.state = best_field.get_state().copy()
            # Rescore so fitness/threats describe the adopted state.
            self._update_scores()
            self.print_field()

        elif self.model == "backtrack":
            Backtrack().calc()
|
|
|
|
|
|
class Genetic:
    """Genetic-algorithm solver for the 8-queens problem.

    Individuals are ``Field`` boards. Their fitness is the number of
    non-attacking queen pairs, so ``MAX_FITNESS`` marks a solution.
    """

    MAX_FITNESS = 28  # 8 choose 2 queen pairs, none attacking

    def __init__(self, size=1000):
        """Create ``size`` random boards as the initial population."""
        self.initial_population = [Field() for _ in range(size)]
        self.p_mutation = 0.1  # probability that a newly bred child is mutated

    def random_selection(self, population):
        """Pick one individual at random, weighted by fitness.

        The higher an individual's fitness, the more likely it is chosen.
        If every fitness is 0 the weights sum to zero, which
        ``random.choices`` rejects, so the pick falls back to uniform.

        Args:
            population: non-empty sequence of individuals with ``get_fitness()``.

        Returns:
            The selected individual.
        """
        weights = [field.get_fitness() for field in population]
        if sum(weights) <= 0:
            return random.choice(population)
        return random.choices(population, weights=weights, k=1)[0]

    def mutation(self, field):
        """Mutate ``field`` in place by moving a random queen to a random row (1..8)."""
        column = random.randint(0, len(field.get_state()) - 1)
        field.move_queen(column, random.randint(1, 8))

    def reproduce(self, x, y):
        """Single-point crossover of two parents.

        The child takes the first ``c`` columns of ``x`` and the remaining
        columns of ``y``. ``c`` is drawn from 1..n (n = number of columns),
        so ``c == n`` yields a copy of ``x``.

        Returns:
            A new ``Field`` holding the child state.
        """
        x_state = x.get_state()
        cut = random.randint(1, len(x_state))
        return Field(x_state[:cut] + y.get_state()[cut:])

    def genetic_algorithm(self, n):
        """Evolve the population for at most ``n`` generations.

        Each generation breeds a population of the same size from
        fitness-weighted parents and mutates each child with probability
        ``p_mutation``. After every generation it prints
        "<generation> <state> <fitness>". It stops as soon as a board reaches
        ``MAX_FITNESS``.

        Returns:
            The fittest board seen, including the initial population.

        Raises:
            ValueError: if the initial population is empty.
        """
        current_population = self.initial_population
        # Start from the fittest initial board so an existing solution is kept.
        best_field = max(current_population, key=lambda field: field.get_fitness())
        if best_field.get_fitness() == self.MAX_FITNESS:
            return best_field

        for generation in range(n):
            new_population = []
            for _ in range(len(self.initial_population)):
                x = self.random_selection(current_population)
                y = self.random_selection(current_population)
                child = self.reproduce(x, y)
                if random.random() < self.p_mutation:
                    self.mutation(child)
                new_population.append(child)

                if child.get_fitness() > best_field.get_fitness():
                    best_field = child
                if best_field.get_fitness() == self.MAX_FITNESS:
                    break

            print(f"{generation} {best_field.get_state()} {best_field.get_fitness()}")
            if best_field.get_fitness() == self.MAX_FITNESS:
                break

            current_population = new_population

        return best_field

    def calc(self, n=100):
        """Run the genetic algorithm for up to ``n`` generations and return the best board."""
        return self.genetic_algorithm(n)
|
|
|
|
|
|
class Backtrack:
    """Backtracking search with forward checking for all 8-queens solutions.

    Queens are placed column by column. A row is only tried when it keeps the
    board free of attacks, and after each placement the next column's domain
    is pruned so dead ends are abandoned early.
    """

    def __init__(self):
        # Solution boards collected by calc(); repeated calc() calls append to it.
        self.results = []

    def consistency(self, field, new_row):
        """Return True if adding a queen in ``new_row`` to the next column of
        ``field`` leaves the whole board without attacking pairs."""
        extended = field.get_state() + [new_row]
        return Field(extended).threats == 0

    def inference(self, field):
        """Forward checking: shrink ``field``'s domain to the consistent rows.

        The domain is rebuilt from all rows 1..8 for every new column, keeping
        only rows consistent with the queens already placed. A full board
        needs no pruning.

        Returns:
            False when the next column has no consistent row (dead end),
            True otherwise.
        """
        if len(field.get_state()) >= 8:
            return True

        field.set_domain_values(
            [candidate for candidate in range(1, 9) if self.consistency(field, candidate)]
        )
        return bool(field.get_domain_values())

    def backtracing(self, field):
        """Return every solution that extends ``field``'s current state.

        ``field`` is mutated while searching but restored (state and domain)
        before returning. The loop keeps iterating the domain list that was
        current on entry, even though inference() rebinds the field's domain.
        """
        state = field.get_state()
        if len(state) == 8:
            return [Field(list(state))]

        found = []
        for candidate in field.get_domain_values():
            saved_domain = list(field.get_domain_values())
            if not self.consistency(field, candidate):
                continue

            field.add_queen(candidate)
            if self.inference(field):  # prunes only the next column
                found.extend(self.backtracing(field))

            # Undo the placement and restore the domain for the next candidate.
            field.get_state().pop()
            field.domain_values = saved_domain

        return found

    def calc(self):
        """Search from each row of the first column, then print every
        solution numbered from 1."""
        for first_row in range(1, 9):
            self.results.extend(self.backtracing(Field([first_row])))

        for number, solution in enumerate(self.results, start=1):
            print(f"{number} {solution.get_state()}")
|
|
|
|
|
|
def main():
    """Demo: solve 8-queens with the genetic solver, list all backtracking
    solutions, then print one random board."""
    for solver in ("genetic", "backtrack"):
        Field(model=solver).calc()

    Field().print_field()
|
|
|
|
|
|
# Run the demo only when executed as a script, not when the module is imported.
if __name__ == "__main__":
    main()
|