optimized inference

Backtrack class
This commit is contained in:
2025-06-22 03:20:31 +02:00
parent 1f93038237
commit 54c7d351a7

149
P2.py
View File

@@ -1,8 +1,9 @@
import random import random
from pickletools import read_uint1
class Field: class Field:
def __init__(self, init_state=None): def __init__(self, init_state=None, model=None):
self.state = [] self.state = []
self.domain_values = [1, 2, 3, 4, 5, 6, 7, 8] self.domain_values = [1, 2, 3, 4, 5, 6, 7, 8]
@@ -15,6 +16,8 @@ class Field:
self.threats = self.collisions(self.state) self.threats = self.collisions(self.state)
self.fitness = 28 - self.threats self.fitness = 28 - self.threats
self.model = model
def get_fitness(self): def get_fitness(self):
return self.fitness return self.fitness
@@ -37,7 +40,7 @@ class Field:
def set_domain_values(self, new_domain): def set_domain_values(self, new_domain):
self.domain_values = new_domain self.domain_values = new_domain
def add_state(self, row): def add_queen(self, row):
if len(self.get_state()) < 8: if len(self.get_state()) < 8:
self.set_state(len(self.get_state()), row) self.set_state(len(self.get_state()), row)
@@ -96,11 +99,20 @@ class Field:
print(" └───┴───┴───┴───┴───┴───┴───┴───┘") print(" └───┴───┴───┴───┴───┴───┴───┴───┘")
print(" A B C D E F G H \n") print(" A B C D E F G H \n")
def calc(self):
if self.model == "genetic":
best_state = Genetic().calc()
self.state = best_state.get_state()
self.print_field()
elif self.model == "backtrack":
Backtrack().calc()
class Genetic: class Genetic:
def __init__(self, size=100): def __init__(self, size=1000):
self.initial_population = [] self.initial_population = []
self.p_mutation = 0 self.p_mutation = 0.1
for i in range(size): for i in range(size):
self.initial_population.append(Field()) self.initial_population.append(Field())
@@ -163,93 +175,96 @@ class Genetic:
if best_field.get_fitness() == 28: if best_field.get_fitness() == 28:
break break
print(f"{i} {best_field.get_state()} {best_field.get_fitness()}")
if best_field.get_fitness() == 28:
break
current_population = new_population current_population = new_population
new_population = [] new_population = []
return best_field return best_field
def calc(self, n=100):
def consistency(field, new_row): best_genetic_field = self.genetic_algorithm(n)
current_state = field.get_state().copy() return best_genetic_field
current_state.append(new_row)
new_field = Field(current_state)
if new_field.threats > 0:
return False
else:
return True
def inference(field): class Backtrack:
inferences = [] def __init__(self):
column = len(field.get_state()) - 1 self.results = []
state = field.get_state().copy()
row = state[column]
for i in range(len(field.get_state()), 8): def consistency(self, field, new_row):
removed_values = [] current_state = field.get_state().copy()
for new_row in field.get_domain_values(): current_state.append(new_row)
new_state = state.copy() new_field = Field(current_state)
new_state.append(new_row)
new_field = Field(new_state)
if new_field.threats > 0:
removed_values.append(new_row)
for i in removed_values: if new_field.threats > 0:
if i in field.get_domain_values(): return False
field.get_domain_values().remove(i) else:
return True
inferences.extend(removed_values) def inference(self, field):
if len(field.get_state()) >= 8:
return True
field.set_domain_values([1, 2, 3, 4, 5, 6, 7, 8]) # Reset für jede Spalte
inferences = []
# print(field.get_state())
# print(field.get_domain_values())
for new_row in range(1, 9):
if not self.consistency(field, new_row):
inferences.append(new_row)
for row in inferences:
if row in field.get_domain_values():
field.get_domain_values().remove(row)
# print(inferences)
# print(f"{field.get_domain_values()}\n")
if len(field.get_domain_values()) == 0: if len(field.get_domain_values()) == 0:
return False, inferences return False
return True, inferences return True
def backtracing(self, field):
if len(field.get_state()) == 8:
return [Field(field.get_state().copy())]
def backtracing(field): solutions = []
if len(field.get_state()) == 8:
return [Field(field.get_state().copy())]
solutions = [] for row in field.get_domain_values():
iter_domain = field.get_domain_values().copy() old_domain_values = field.get_domain_values().copy()
if self.consistency(field, row):
field.add_queen(row)
if self.inference(field): # nur für die nächste Spalte
result = self.backtracing(field)
if len(result) != 0:
solutions.extend(result)
for row in iter_domain: field.get_state().pop()
if consistency(field, row): field.domain_values = old_domain_values
field.add_state(row)
result = backtracing(field)
if len(result) != 0:
solutions.extend(result)
field.get_state().pop() return solutions
return solutions def calc(self):
for i in range(1, 9):
result = self.backtracing(Field([i]))
def backtracing_helper(field): self.results.extend(result)
results = backtracing(field) for i, result in enumerate(self.results):
print(f"{i + 1} {result.get_state()}")
return results
def main(): def main():
new_field = Field( gen_field = Field(model="genetic")
init_state=[]) # [8, 4, 5, 4, 4, 3, 7, 6] [5, 5, 5, 5, 1, 2, 8, 5] [6,3,5,7,1,4,2,8] gen_field.calc()
new_field.print_field()
print(new_field.collisions())
print("Backtrack Algorithm") back_field = Field(model="backtrack")
results = [] back_field.calc()
for i in range(1, 9):
results.extend(backtracing_helper(Field([i])))
for i, result in enumerate(results):
print(f"{i+1} {result.get_state()}")
print("Genetic Algorithm") myField = Field()
genetic = Genetic(500) myField.print_field()
best_genetic_field = genetic.genetic_algorithm(100)
best_genetic_field.print_field()
print(best_genetic_field.get_fitness())
main() main()