csci5451/assignments/01/generate_test_data.py

69 lines
2 KiB
Python
Raw Permalink Normal View History

2023-10-09 07:51:38 +00:00
import numpy as np
2023-09-23 05:04:06 +00:00
import random
import click
2023-10-09 07:51:38 +00:00
import pathlib
2023-09-23 05:04:06 +00:00
def evaluate(w, p):
    """Return the dot product of ``w`` and ``p``.

    Elements are paired positionally with ``zip``, so any surplus
    elements in the longer sequence are ignored; empty input gives ``0``.
    """
    return sum(weight * value for weight, value in zip(w, p))
def _read_matrix(path):
    """Read a ``"<rows> <dimensions>"`` header followed by whitespace-separated rows.

    Returns ``(data, rows, dimensions)``; ``data`` is always 2-D.

    Raises:
        click.ClickException: if the number of rows/columns in the body
            does not match the header.
    """
    with open(path, "r") as f:
        rows, dimensions = map(int, f.readline().split())
        # ndmin=2 keeps a single-row or single-column file two-dimensional,
        # which the ``data[indices, :]`` indexing below relies on.
        data = np.loadtxt(f, ndmin=2)
    if data.shape[0] != rows or (rows > 0 and data.shape[1] != dimensions):
        raise click.ClickException(
            f"{path}: header declares {rows}x{dimensions} but the file "
            f"contains {data.shape[0]}x{data.shape[1]} values"
        )
    return data, rows, dimensions


def _read_labels(path):
    """Read a ``"<rows>"`` header followed by one label per line.

    Returns ``(labels, rows)``; ``labels`` is always at least 1-D.

    Raises:
        click.ClickException: if the number of labels does not match the header.
    """
    with open(path, "r") as f:
        rows = int(f.readline().strip())
        # ndmin=1 keeps a single-label file indexable as an array.
        labels = np.loadtxt(f, ndmin=1)
    if labels.shape[0] != rows:
        raise click.ClickException(
            f"{path}: header declares {rows} labels but the file "
            f"contains {labels.shape[0]}"
        )
    return labels, rows


def _write_matrix(path, matrix, dimensions):
    """Write ``"<rows> <dimensions>"`` then one space-separated line per row."""
    with open(path, "w") as f:
        f.write(f"{len(matrix)} {dimensions}\n")
        f.writelines(" ".join(map(str, row)) + "\n" for row in matrix)


def _write_vector(path, vector):
    """Write ``"<count>"`` then one value per line."""
    with open(path, "w") as f:
        f.write(f"{len(vector)}\n")
        f.writelines(f"{cell}\n" for cell in vector)


@click.command()
@click.argument('data_path')
@click.argument('label_path')
@click.argument('out_path')
@click.option('--seed', type=int, default=None,
              help='Seed for the shuffle; omit for a non-reproducible split.')
@click.option('--train-fraction', type=float, default=0.7, show_default=True,
              help='Fraction of rows placed in the training split (0 to 1).')
def generate_test_data(data_path: str, label_path: str, out_path: str,
                       seed=None, train_fraction: float = 0.7):
    """Randomly split DATA_PATH/LABEL_PATH into train and test files.

    DATA_PATH starts with a "<rows> <dimensions>" header followed by the
    rows; LABEL_PATH starts with a "<rows>" header followed by one label per
    line. Both must describe the same number of rows. The shuffled rows are
    split by --train-fraction and written to OUT_PATH (created if needed) as
    train_data.txt, test_data.txt, train_label.txt and test_label.txt, using
    the same header formats.
    """
    if not 0.0 <= train_fraction <= 1.0:
        raise click.BadParameter("must be between 0 and 1",
                                 param_hint="'--train-fraction'")

    data, rows, dimensions = _read_matrix(data_path)
    print("loaded data")
    labels, label_rows = _read_labels(label_path)
    print("loaded labels")
    # Labels are selected with the same indices as the data rows, so the
    # two files must line up one-to-one.
    if label_rows != rows:
        raise click.ClickException(
            f"{data_path} has {rows} rows but {label_path} has {label_rows} labels"
        )

    indices = list(range(rows))
    if seed is None:
        # Keep the original behaviour: the module-level RNG.
        random.shuffle(indices)
    else:
        random.Random(seed).shuffle(indices)
    split_at = int(train_fraction * rows)
    train_indices = indices[:split_at]
    test_indices = indices[split_at:]

    out_dir = pathlib.Path(out_path)
    out_dir.mkdir(exist_ok=True, parents=True)
    _write_matrix(out_dir / "train_data.txt", data[train_indices, :], dimensions)
    _write_matrix(out_dir / "test_data.txt", data[test_indices, :], dimensions)
    _write_vector(out_dir / "train_label.txt", labels[train_indices])
    _write_vector(out_dir / "test_label.txt", labels[test_indices])
2023-09-23 05:04:06 +00:00
if __name__ == "__main__":
    # Script entry point: hand control to the click command defined above.
    generate_test_data()