-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
138 lines (108 loc) · 4.47 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import numpy as np
import matplotlib.pyplot as plt
from time import time
from tqdm import tqdm
from utils import get_dates, to_lonlat, transform_coords, to_canvas
from village import Village
class Model:
    """
    Raster-based simulation that spreads Village agents over an environment
    layer read from an ASC file and scores simulated arrival times against
    archaeological dates.
    """

    # The ASC layer starts with six header lines, read in this order:
    # ncols, nrows, x-origin, y-origin, cellsize, nodata value.
    _ASC_HEADER_LINES = 6

    def __init__(self, start, params, env_path='./layers/env.asc'):
        """
        Build the grid from the environment layer and seed the first village.

        Args:
            start: Initial model date stored in ``self.bp``; ``step``
                decrements it by one.
            params: Village parameters passed to ``Village`` as keyword
                arguments (initial coords, fission threshold, K*, catchment,
                leap distance, permanence). The dict is copied, so the
                caller's object is never modified and can safely be reused
                to build several models.
            env_path: Path of the ASC environment layer. Defaults to the
                location that used to be hard-coded.
        """
        with open(env_path, 'r') as f:
            self.header = [next(f) for _ in range(self._ASC_HEADER_LINES)]
        # Dimensions, origin and nodata value of the asc layers
        (self.width, self.height, self.xmin,
         self.ymax, self.nodata) = self._parse_header(self.header)
        # Output raster of arrival times; unreached cells keep the nodata value
        self.img = np.full((self.height, self.width), self.nodata)
        self.agents = {}
        # Start at earliest date in the model
        self.bp = start
        # Copy before adjusting: mutating the caller's dict would rescale
        # catchment/leap_distance again for every Model built from it.
        self.params = dict(params)
        self.params['coords'] = to_canvas(params['coords'], self.xmin, self.ymax)
        # NOTE(review): presumably converts distances into grid-cell units;
        # confirm the factor 10 matches the layer's cellsize.
        self.params['catchment'] //= 10
        self.params['leap_distance'] //= 10
        # One pass builds every cell: agent, land ownership, arrival date and
        # ecological niche value.
        env = np.loadtxt(env_path, skiprows=self._ASC_HEADER_LINES)
        self.grid = {}
        for y in range(self.height):
            for x in range(self.width):
                cell_env = env[y, x]
                self.grid[(x, y)] = {
                    'agent': 0,
                    # Water (nodata) cells are pre-owned so they are never settled
                    'owner': -1 if cell_env == self.nodata else 0,
                    'arrival_time': 0,
                    'env': cell_env,
                }
        self.setup_agents()

    @staticmethod
    def _parse_header(header):
        """
        Parse the six ASC header lines.

        Returns:
            (width, height, xmin, ymax, nodata), where ymax is the y-origin
            plus ``height * cellsize``.

        cellsize and nodata are parsed through ``float`` so that headers such
        as ``cellsize 0.5`` or ``NODATA_value -9999.0`` are accepted.
        """
        values = [line.split()[1] for line in header]
        width = int(values[0])
        height = int(values[1])
        xmin = float(values[2])
        cellsize = float(values[4])
        ymax = float(values[3]) + height * cellsize
        nodata = int(float(values[5]))
        return width, height, xmin, ymax, nodata

    def setup_agents(self):
        """
        Creates a village, adds land to its territory and records its start date.
        """
        village = Village(self, **self.params)
        self.agents[village._id] = village
        self.grid[village.coords]['agent'] = village._id
        village.claim_land(village.coords)
        village.record_date()

    def _date_score(self, coords, dates):
        """
        Score one dated site.

        The score is the probability of the simulated arrival time at
        ``coords``, normalised by the maximum of that site's distribution.
        It is 0 when the cell was never reached (arrival_time 0) or when the
        simulated date is not in the distribution.
        """
        sim_date = self.grid[coords]['arrival_time']
        distribution = dates[coords]
        if sim_date and sim_date in distribution:
            return distribution[sim_date] / max(distribution.values())
        return 0

    def eval(self):
        """
        Returns a score from 0 to 1 of model fitness based on match with
        archaeological dates (mean of the per-site scores; 0.0 if there are
        no dates).
        """
        dates = get_dates(self.xmin, self.ymax)
        if not dates:
            return 0.0
        return sum(self._date_score(coords, dates) for coords in dates) / len(dates)

    def write(self):
        """
        Writes the simulated arrival times to an asc file and scores of
        archaeological dates to a csv file, both under ./results and tagged
        with the current Unix timestamp.
        """
        timestamp = int(time())
        # Drop the final newline: np.savetxt adds its own after the header.
        np.savetxt(f'./results/sim{timestamp}.asc',
                   self.img, header=''.join(self.header)[:-1], comments='')
        date_file = f'./results/dates{timestamp}.csv'
        dates = get_dates(self.xmin, self.ymax)
        with open(date_file, 'w') as file:
            file.write('x,y,score\n')
            for coords in dates:
                score = self._date_score(coords, dates)
                x, y = to_lonlat(transform_coords(coords, self.xmin, self.ymax))
                file.write(f'{x},{y},{score}\n')

    def step(self):
        """
        Step every agent once, then move the model clock back by one (bp - 1).
        """
        # Snapshot the ids: agents stepped here may add new agents to the dict.
        for _id in list(self.agents):
            self.agents[_id].step()
        self.bp -= 1

    def plot(self):
        """Show the arrival-time raster, with nodata cells masked as NaN."""
        plt.rcParams['figure.dpi'] = 150
        img = self.img.astype('float')
        img[img == self.nodata] = np.nan
        plt.imshow(img)
        plt.show()

    def run(self, num_iter, show_prog=False, plot=False):
        """
        Run ``num_iter`` steps, then copy arrival times into ``self.img``.

        Args:
            num_iter: Number of steps to run.
            show_prog: Wrap the loop in a tqdm progress bar.
            plot: Show the resulting raster when done.
        """
        steps = tqdm(range(num_iter)) if show_prog else range(num_iter)
        for _ in steps:
            self.step()
        for (x, y), cell in self.grid.items():
            if cell['arrival_time']:
                self.img[y, x] = cell['arrival_time']
        if plot:
            self.plot()