2023-10-09 07:51:38 +00:00
|
|
|
import numpy as np
|
2023-09-23 05:04:06 +00:00
|
|
|
import random
|
|
|
|
import click
|
2023-10-09 07:51:38 +00:00
|
|
|
import pathlib
|
2023-09-23 05:04:06 +00:00
|
|
|
|
|
|
|
def evaluate(w, p):
    """Return the dot product of *w* and *p*.

    The two iterables are paired element by element with ``zip`` and the
    products of each pair are summed.

    Args:
        w: Iterable of numbers (for example weights).
        p: Iterable of numbers paired position by position with *w*.

    Returns:
        The sum of ``w_i * p_i`` over all pairs; ``0`` when either input is
        empty.

    Note:
        ``zip`` stops at the shorter input, so surplus elements of the
        longer iterable are ignored rather than reported as an error.
    """
    return sum(w_i * p_i for w_i, p_i in zip(w, p))
|
|
|
|
|
|
|
|
# Share of the samples that goes into the training split; the rest is test.
_TRAIN_FRACTION = 0.7


def _write_matrix(path, matrix, dimensions):
    """Write a ``<rows> <dimensions>`` header, then one space-separated row per line."""
    with open(path, "w") as f:
        f.write(f"{len(matrix)} {dimensions}\n")
        f.writelines(
            " ".join(str(cell) for cell in row) + "\n" for row in matrix
        )


def _write_labels(path, labels):
    """Write a ``<rows>`` header, then one label per line."""
    with open(path, "w") as f:
        f.write(f"{len(labels)}\n")
        f.writelines(f"{cell}\n" for cell in labels)


@click.command()
@click.argument('data_path')
@click.argument('label_path')
@click.argument('out_path')
def generate_test_data(data_path: str, label_path: str, out_path: str):
    """Randomly split a data/label file pair into train and test files.

    DATA_PATH is a text file whose first line is ``<rows> <dimensions>``,
    followed by whitespace-separated numeric rows. LABEL_PATH is a text file
    whose first line is ``<rows>``, followed by the labels. Row ``i`` of the
    data belongs to label ``i``.

    The row indices are shuffled with ``random.shuffle`` (no seed is set
    here, so the split differs between runs); the first 70% go to the
    training split and the remainder to the test split. Four files are
    written into OUT_PATH (created if missing, including parents):
    ``train_data.txt``, ``test_data.txt``, ``train_label.txt`` and
    ``test_label.txt``, each using the same header layout as its input.

    Raises:
        click.ClickException: If the row counts declared in the two headers
            and the row counts actually loaded do not all agree.
    """
    with open(data_path, "r") as f:
        data_rows, dimensions = map(int, f.readline().split())
        # ndmin=2 keeps the array two-dimensional even when the file holds a
        # single row or a single column; otherwise loadtxt squeezes it to
        # 1-D and the row indexing below fails.
        data = np.loadtxt(f, ndmin=2)
    print("loaded data")

    with open(label_path, "r") as f:
        label_rows = int(f.readline().strip())
        # ndmin=1 keeps a single label from collapsing to a 0-d array.
        labels = np.loadtxt(f, ndmin=1)
    print("loaded labels")

    # Every sample needs exactly one label. Fail with a clear message
    # instead of an IndexError or a silently truncated data set.
    if not (data_rows == label_rows == data.shape[0] == labels.shape[0]):
        raise click.ClickException(
            "row count mismatch: "
            f"data header={data_rows}, data rows loaded={data.shape[0]}, "
            f"label header={label_rows}, label rows loaded={labels.shape[0]}"
        )

    indices = list(range(label_rows))
    random.shuffle(indices)
    split_at = int(_TRAIN_FRACTION * label_rows)
    train_indices = indices[:split_at]
    test_indices = indices[split_at:]

    train_data = data[train_indices, :]
    train_label = labels[train_indices]

    test_data = data[test_indices, :]
    test_label = labels[test_indices]

    out_dir = pathlib.Path(out_path)
    out_dir.mkdir(exist_ok=True, parents=True)

    _write_matrix(out_dir / "train_data.txt", train_data, dimensions)
    _write_matrix(out_dir / "test_data.txt", test_data, dimensions)

    _write_labels(out_dir / "train_label.txt", train_label)
    _write_labels(out_dir / "test_label.txt", test_label)
|
2023-09-23 05:04:06 +00:00
|
|
|
|
|
|
|
# Script entry point: when this file is executed directly (not imported),
# run the click command, which reads its arguments from the command line.
if __name__ == "__main__":
    generate_test_data()
|