ok i think the openmp one works now

This commit is contained in:
Michael Zhang 2023-10-09 02:51:38 -05:00
parent 750876761d
commit 29d791e1e6
16 changed files with 332 additions and 96 deletions

View file

@ -1,26 +1,42 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile
{
"name": "Existing Dockerfile",
"build": {
// Sets the run context to one level up instead of the .devcontainer folder.
"context": "..",
// Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
"dockerfile": "../Dockerfile"
}
"name": "Existing Dockerfile",
"build": {
// Sets the run context to one level up instead of the .devcontainer folder.
"context": "..",
// Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename.
"dockerfile": "../Dockerfile"
},
// Features to add to the dev container. More info: https://containers.dev/features.
// "features": {},
// Features to add to the dev container. More info: https://containers.dev/features.
"features": {},
"customizations": {
"vscode": {
"extensions": [
"eamodio.gitlens",
"esbenp.prettier-vscode",
"llvm-vs-code-extensions.vscode-clangd",
"ms-azuretools.vscode-docker",
"ms-python.python",
"ms-python.vscode-pylance",
"ms-vscode.cpptools",
"ms-vscode.makefile-tools",
"rust-lang.rust-analyzer",
"tomoki1207.pdf"
]
}
}
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Use 'forwardPorts' to make a list of ports inside the container available locally.
// "forwardPorts": [],
// Uncomment the next line to run commands after the container is created.
// "postCreateCommand": "cat /etc/os-release",
// Uncomment the next line to run commands after the container is created.
// "postCreateCommand": "cat /etc/os-release",
// Configure tool-specific properties.
// "customizations": {},
// Configure tool-specific properties.
// "customizations": {},
// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "devcontainer"
// Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root.
// "remoteUser": "devcontainer"
}

View file

@ -1,5 +1,9 @@
{
"files.associations": {
"common.h": "c"
}
},
"[python]": {
"editor.defaultFormatter": "ms-python.black-formatter"
},
"python.formatting.provider": "none"
}

View file

@ -1,4 +1,19 @@
FROM ubuntu:22.04
ENV PATH="/root/.cargo/bin:${PATH}"
RUN apt update -y && apt install -y --no-install-recommends \
git make gcc valgrind
build-essential \
ca-certificates \
clangd \
curl \
direnv \
git \
libomp-dev \
python3 \
python3-pip \
valgrind \
;
RUN pip install poetry
RUN curl https://sh.rustup.rs -sSf | bash -s -- -y
RUN echo 'eval "$(direnv hook bash)"' >> /root/.bashrc

View file

@ -1,2 +1,2 @@
CompileFlags:
Add: -I/opt/homebrew/opt/libomp/include
Add: -I/opt/homebrew/opt/libomp/include -I/usr/lib/gcc/aarch64-linux-gnu/11/include

View file

@ -1,2 +1,2 @@
export BASE_PATH=$PWD
export CC=clang-17
# export CC=clang-17

View file

@ -1,13 +1,20 @@
.PHONY: all watch-openmp clean
CFLAGS := -std=c11 -fopenmp -I/opt/homebrew/opt/libomp/include -g
LDFLAGS := -std=c11 -fopenmp -L/opt/homebrew/opt/libomp/lib -g
CFLAGS := -std=c11 -fopenmp \
-I/opt/homebrew/opt/libomp/include \
-I/usr/lib/gcc/aarch64-linux-gnu/11/include \
-O3
LDFLAGS := -std=c11 -fopenmp -L/opt/homebrew/opt/libomp/lib -O3
RUST_SOURCES := $(shell find . -name "*.rs")
all: lc_openmp lc_pthreads
clean:
rm -rf lc_openmp lc_pthreads zhan4854 zhan4854.tar.gz *.o
rm -rf \
lc_openmp lc_pthreads \
zhan4854 zhan4854.tar.gz \
dataset/small \
*.o
zhan4854.tar.gz: common.c common.h lc_openmp.c lc_pthreads.c Makefile
mkdir -p zhan4854
@ -15,9 +22,6 @@ zhan4854.tar.gz: common.c common.h lc_openmp.c lc_pthreads.c Makefile
tar -czvf $@ zhan4854
rm -r zhan4854
watch-openmp:
watchexec -c clear -e Makefile,c,h 'make lc_openmp && ./lc_openmp ./dataset/small_data.csv ./dataset/small_label.csv 10 2'
lc_openmp: lc_openmp.o common.o
$(CC) $(CFLAGS) $(LDFLAGS) -o $@ $^
@ -27,5 +31,22 @@ lc_pthreads: lc_pthreads.o common.o
%.o: %.c
$(CC) $(CFLAGS) -o $@ -c $<
## Dumb debug stuff, please ignore:
dataset/small/%.txt: generate_test_data.py
python generate_test_data.py dataset/small_data.csv dataset/small_label.csv dataset/small
dataset/mnist/%.txt: generate_test_data.py
python generate_test_data.py dataset/MNIST_data.csv dataset/MNIST_label.csv dataset/mnist
watch-openmp:
watchexec -c clear -e Makefile,c,h 'make lc_openmp && ./lc_openmp ./dataset/small_data.csv ./dataset/small_label.csv 10 2'
run-openmp-small: lc_openmp dataset/small/train_data.txt
./lc_openmp dataset/small/train_data.txt dataset/small/train_label.txt 10 2 dataset/small/test_data.txt dataset/small/test_label.txt
run-openmp-mnist: lc_openmp dataset/mnist/train_data.txt
./lc_openmp dataset/mnist/train_data.txt dataset/mnist/train_label.txt 10 2 dataset/mnist/test_data.txt dataset/mnist/test_label.txt
rust: $(RUST_SOURCES)
cargo run -- ${BASE_PATH}/dataset/{small_data.csv,small_label.csv} 10 2

View file

@ -1,20 +1,68 @@
import numpy as np
import random
import click
import pathlib
def evaluate(w, p):
result = sum(map(lambda s: s[0] * s[1], zip(w, p)))
return result
@click.command()
@click.option('--dimensions', default=2, help='Number of dimensions')
@click.option('--count', default=2000, help='How many points to generate')
def generate_test_data(dimensions: int, count: int):
actual_w = [random.uniform(0.0, 1.0) for _ in range(dimensions)]
@click.argument('data_path')
@click.argument('label_path')
@click.argument('out_path')
def generate_test_data(data_path: str, label_path: str, out_path: str):
with open(data_path, "r") as f:
desc = f.readline().strip()
rows, dimensions = map(int, desc.split())
data = np.loadtxt(f)
print("loaded data")
for _ in range(count):
point = [random.uniform(0.0, 1.0) for _ in range(dimensions)]
y = evaluate(actual_w, point)
print(point, y)
with open(label_path, "r") as f:
desc = f.readline().strip()
rows = int(desc)
labels = np.loadtxt(f)
print("loaded labels")
indices = list(range(rows))
random.shuffle(indices)
split_at = int(0.7 * rows)
train_indices = indices[:split_at]
test_indices = indices[split_at:]
# print("WTF?", train_indices, test_indices)
train_data = data[train_indices,:]
train_label = labels[train_indices]
test_data = data[test_indices,:]
test_label = labels[test_indices]
out_path2 = pathlib.Path(out_path)
out_path2.mkdir(exist_ok=True, parents=True)
with open(out_path2 / "train_data.txt", "w") as f:
f.write(f"{len(train_data)} {dimensions}\n")
for row in train_data:
for i, cell in enumerate(row):
if i > 0: f.write(" ")
f.write(str(cell))
f.write("\n")
with open(out_path2 / "test_data.txt", "w") as f:
f.write(f"{len(test_data)} {dimensions}\n")
for row in test_data:
for i, cell in enumerate(row):
if i > 0: f.write(" ")
f.write(str(cell))
f.write("\n")
with open(out_path2 / "train_label.txt", "w") as f:
f.write(f"{len(train_label)}\n")
for cell in train_label:
f.write(f"{cell}\n")
with open(out_path2 / "test_label.txt", "w") as f:
f.write(f"{len(test_label)}\n")
for cell in test_label:
f.write(f"{cell}\n")
if __name__ == '__main__':
generate_test_data()

View file

@ -7,6 +7,13 @@
#include "common.h"
int main(int argc, char **argv) {
if (argc < 5) {
fprintf(stderr,
"USAGE: %s data_file label_file outer_iterations thread_count",
argv[0]);
exit(1);
}
char *data_file_name = argv[1], *label_file_name = argv[2];
int outer_iterations = atoi(argv[3]);
int thread_count = atoi(argv[4]);
@ -15,6 +22,13 @@ int main(int argc, char **argv) {
struct data *data = read_data(data_file_name);
struct labels *label = read_labels(label_file_name);
// NAN CHECK
for (int i = 0; i < data->dimensions * data->rows; i++) {
if (isnan(data->buf[i]))
printf("failed at index %d\n", i);
}
printf("Running %d iteration(s) with %d thread(s).\n", outer_iterations,
thread_count);
@ -28,11 +42,11 @@ int main(int argc, char **argv) {
#pragma omp parallel for default(shared)
for (int i = 0; i < data->dimensions; i++) {
#pragma omp parallel for default(shared)
// #pragma omp parallel for default(shared)
for (int j = 0; j < data->rows; j++) {
FLOAT x_ni_w_ni = 0;
#pragma omp parallel for default(shared) reduction(+ : x_ni_w_ni)
// #pragma omp parallel for default(shared) reduction(+ : x_ni_w_ni)
for (int i2 = 0; i2 < data->dimensions; i2++) {
if (i2 == i)
continue;
@ -40,40 +54,72 @@ int main(int argc, char **argv) {
x_ni_w_ni = data->buf[data->rows * i2 + j] * w[i2];
}
ouais[data->dimensions * i + j] = label->buf[j] - x_ni_w_ni;
ouais[data->rows * i + j] = label->buf[j] - x_ni_w_ni;
}
FLOAT numer = 0, denom = 0;
#pragma omp parallel for default(shared) reduction(+ : numer, denom)
// #pragma omp parallel for default(shared) reduction(+ : numer, denom)
for (int j = 0; j < data->rows; j++) {
FLOAT xij = data->buf[data->dimensions * i + j];
numer = xij * ouais[data->dimensions * i + j];
denom = xij * xij;
FLOAT xij = data->buf[data->rows * i + j];
numer += xij * ouais[data->rows * i + j];
denom += xij * xij;
}
new_w[i] = numer / denom;
if (denom == 0) {
new_w[i] = 0;
// printf("wtf? %f\n", numer);
} else
new_w[i] = numer / denom;
}
printf("Done.\n");
for (int idx = 0; idx < data->dimensions; idx++) {
printf("%.3f ", new_w[idx]);
}
printf("\n");
memcpy(w, new_w, sizeof(w));
double end_time = monotonic_seconds();
print_time(end_time - start_time);
printf("Done.\nw = [");
for (int idx = 0; idx < data->dimensions; idx++) {
w[idx] = new_w[idx];
printf("%.3f ", w[idx]);
}
printf("]\n");
// memcpy(w, new_w, data->dimensions * sizeof(FLOAT));
}
free(ouais);
free(w);
free(new_w);
free(data->buf);
free(label->buf);
free(data);
free(label);
// NOTE: NOT PART OF THE ASSIGNMENT
// Perform testing to ensure that the training actually works
if (argc >= 7) {
struct data *test_data = read_data(argv[5]);
struct labels *test_label = read_labels(argv[6]);
int num_correct = 0;
for (int j = 0; j < test_data->rows; j++) {
FLOAT output = 0;
for (int i = 0; i < test_data->dimensions; i++) {
output += test_data->buf[test_data->rows * i + j] * w[i];
}
// printf("expected: %f, actual: %f\n", test_label->buf[j], output);
FLOAT correct_answer = test_label->buf[j];
FLOAT incorrect_answer = -correct_answer;
if (fabs(output - correct_answer) < fabs(output - incorrect_answer))
num_correct += 1;
}
printf("num correct: %d, out of %d (%.2f%%)\n", num_correct,
test_data->rows, (100.0 * num_correct) / test_data->rows);
}
free(w);
return 0;
}

View file

@ -1,3 +1,18 @@
#include <stdio.h>
#include "common.h"
int main() { return 0; }
int main(int argc, char **argv) {
if (argc < 5) {
fprintf(stderr,
"USAGE: %s data_file label_file outer_iterations thread_count",
argv[0]);
exit(1);
}
char *data_file_name = argv[1], *label_file_name = argv[2];
int outer_iterations = atoi(argv[3]);
int thread_count = atoi(argv[4]);
return 0;
}

View file

@ -16,9 +16,12 @@ print("loaded labels")
print(data.shape)
print(labels.shape)
w = np.empty((dimensions, 1))
new_w = np.empty(w.shape)
w = np.zeros(dimensions)
print(w)
np.set_printoptions(precision=3)
for _ in range(10):
new_w = np.empty(w.shape)
for i in range(dimensions):
data_ni = np.delete(data, i, axis=1)
w_ni = np.delete(w, i)
@ -26,9 +29,10 @@ for _ in range(10):
res = data_ni @ w_ni
x_i = data[:,i]
numer = x_i.transpose() @ (labels - np.matmul(data_ni, w_ni))
numer = x_i.transpose() @ (labels - res)
denom = x_i.transpose() @ x_i
new_w[i] = numer / denom
print("new_w", np.round(new_w, 3))
w = new_w
print("w", new_w)

View file

@ -25,7 +25,48 @@ files = [
{file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"},
]
[[package]]
name = "numpy"
version = "1.26.0"
description = "Fundamental package for array computing in Python"
optional = false
python-versions = "<3.13,>=3.9"
files = [
{file = "numpy-1.26.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:f8db2f125746e44dce707dd44d4f4efeea8d7e2b43aace3f8d1f235cfa2733dd"},
{file = "numpy-1.26.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:0621f7daf973d34d18b4e4bafb210bbaf1ef5e0100b5fa750bd9cde84c7ac292"},
{file = "numpy-1.26.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:51be5f8c349fdd1a5568e72713a21f518e7d6707bcf8503b528b88d33b57dc68"},
{file = "numpy-1.26.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:767254ad364991ccfc4d81b8152912e53e103ec192d1bb4ea6b1f5a7117040be"},
{file = "numpy-1.26.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:436c8e9a4bdeeee84e3e59614d38c3dbd3235838a877af8c211cfcac8a80b8d3"},
{file = "numpy-1.26.0-cp310-cp310-win32.whl", hash = "sha256:c2e698cb0c6dda9372ea98a0344245ee65bdc1c9dd939cceed6bb91256837896"},
{file = "numpy-1.26.0-cp310-cp310-win_amd64.whl", hash = "sha256:09aaee96c2cbdea95de76ecb8a586cb687d281c881f5f17bfc0fb7f5890f6b91"},
{file = "numpy-1.26.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:637c58b468a69869258b8ae26f4a4c6ff8abffd4a8334c830ffb63e0feefe99a"},
{file = "numpy-1.26.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:306545e234503a24fe9ae95ebf84d25cba1fdc27db971aa2d9f1ab6bba19a9dd"},
{file = "numpy-1.26.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8c6adc33561bd1d46f81131d5352348350fc23df4d742bb246cdfca606ea1208"},
{file = "numpy-1.26.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e062aa24638bb5018b7841977c360d2f5917268d125c833a686b7cbabbec496c"},
{file = "numpy-1.26.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:546b7dd7e22f3c6861463bebb000646fa730e55df5ee4a0224408b5694cc6148"},
{file = "numpy-1.26.0-cp311-cp311-win32.whl", hash = "sha256:c0b45c8b65b79337dee5134d038346d30e109e9e2e9d43464a2970e5c0e93229"},
{file = "numpy-1.26.0-cp311-cp311-win_amd64.whl", hash = "sha256:eae430ecf5794cb7ae7fa3808740b015aa80747e5266153128ef055975a72b99"},
{file = "numpy-1.26.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:166b36197e9debc4e384e9c652ba60c0bacc216d0fc89e78f973a9760b503388"},
{file = "numpy-1.26.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:f042f66d0b4ae6d48e70e28d487376204d3cbf43b84c03bac57e28dac6151581"},
{file = "numpy-1.26.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e5e18e5b14a7560d8acf1c596688f4dfd19b4f2945b245a71e5af4ddb7422feb"},
{file = "numpy-1.26.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7f6bad22a791226d0a5c7c27a80a20e11cfe09ad5ef9084d4d3fc4a299cca505"},
{file = "numpy-1.26.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:4acc65dd65da28060e206c8f27a573455ed724e6179941edb19f97e58161bb69"},
{file = "numpy-1.26.0-cp312-cp312-win32.whl", hash = "sha256:bb0d9a1aaf5f1cb7967320e80690a1d7ff69f1d47ebc5a9bea013e3a21faec95"},
{file = "numpy-1.26.0-cp312-cp312-win_amd64.whl", hash = "sha256:ee84ca3c58fe48b8ddafdeb1db87388dce2c3c3f701bf447b05e4cfcc3679112"},
{file = "numpy-1.26.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:4a873a8180479bc829313e8d9798d5234dfacfc2e8a7ac188418189bb8eafbd2"},
{file = "numpy-1.26.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:914b28d3215e0c721dc75db3ad6d62f51f630cb0c277e6b3bcb39519bed10bd8"},
{file = "numpy-1.26.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c78a22e95182fb2e7874712433eaa610478a3caf86f28c621708d35fa4fd6e7f"},
{file = "numpy-1.26.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:86f737708b366c36b76e953c46ba5827d8c27b7a8c9d0f471810728e5a2fe57c"},
{file = "numpy-1.26.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:b44e6a09afc12952a7d2a58ca0a2429ee0d49a4f89d83a0a11052da696440e49"},
{file = "numpy-1.26.0-cp39-cp39-win32.whl", hash = "sha256:5671338034b820c8d58c81ad1dafc0ed5a00771a82fccc71d6438df00302094b"},
{file = "numpy-1.26.0-cp39-cp39-win_amd64.whl", hash = "sha256:020cdbee66ed46b671429c7265cf00d8ac91c046901c55684954c3958525dab2"},
{file = "numpy-1.26.0-pp39-pypy39_pp73-macosx_10_9_x86_64.whl", hash = "sha256:0792824ce2f7ea0c82ed2e4fecc29bb86bee0567a080dacaf2e0a01fe7654369"},
{file = "numpy-1.26.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d484292eaeb3e84a51432a94f53578689ffdea3f90e10c8b203a99be5af57d8"},
{file = "numpy-1.26.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:186ba67fad3c60dbe8a3abff3b67a91351100f2661c8e2a80364ae6279720299"},
{file = "numpy-1.26.0.tar.gz", hash = "sha256:f93fc78fe8bf15afe2b8d6b6499f1c73953169fad1e9a8dd086cdff3190e7fdf"},
]
[metadata]
lock-version = "2.0"
python-versions = "^3.11"
content-hash = "628b2f91879695db7c6062f64a17355697ff1237691b4f4f3105875e3b3eb929"
python-versions = ">=3.10,<3.13"
content-hash = "bddb5c36e945042dc9b3146b4a98b8bc383ce35a00ae8ca34fa85d5641790862"

View file

@ -6,8 +6,9 @@ authors = ["Michael Zhang <mail@mzhang.io>"]
readme = "README.md"
[tool.poetry.dependencies]
python = "^3.11"
python = ">=3.10,<3.13"
click = "^8.1.7"
numpy = "^1.26.0"
[build-system]

View file

@ -0,0 +1,5 @@
#!/usr/bin/env bash
exec "$1" \
$BASE_PATH/dataset/MNIST_data.csv \
$BASE_PATH/dataset/MNIST_label.csv \
10 2

View file

@ -0,0 +1,5 @@
#!/usr/bin/env bash
exec "$1" \
$BASE_PATH/dataset/small_data.csv \
$BASE_PATH/dataset/small_label.csv \
10 2

View file

@ -1,8 +1,9 @@
use std::fmt::Debug;
use std::{marker::PhantomData, ops::Index};
pub struct Span<'a, T, F, U = usize>(&'a [T], F, PhantomData<U>);
impl<'a, T, F, U> Span<'a, T, F, U> {
pub fn new(slice: &[T], func: F) -> Self {
pub fn new(slice: &'a [T], func: F) -> Self {
Span(slice, func, PhantomData::default())
}
}
@ -10,12 +11,16 @@ impl<'a, T, F, U> Span<'a, T, F, U> {
impl<'a, T, F, U> Index<U> for Span<'a, T, F, U>
where
F: Fn(U) -> usize,
U: Debug,
{
type Output = T;
#[inline]
fn index(&self, index: U) -> &Self::Output {
let index = (self.1)(index);
&self.0[index]
// print!("transformed {index:?}");
let transformed_index = (self.1)(index);
// println!(" into {transformed_index:?}");
&self.0[transformed_index]
}
}
@ -24,7 +29,7 @@ fn test() {
let w = (0..100).collect::<Vec<_>>();
// Exclude #4
let w_i_inv: Span<_, _, usize> = Span(w, fuck::<4>, PhantomData::default());
let w_i_inv: Span<_, _, usize> = Span(&w, fuck::<4>, PhantomData::default());
for i in 0..99 {
let n = w_i_inv[i];

View file

@ -29,6 +29,14 @@ struct Data {
dimensions: usize,
buf: Vec<f64>,
}
impl Data {
#[inline]
fn get(&self, i: usize, j: usize) -> f64 {
self.buf[self.rows * i + j]
}
}
struct Label {
num_points: usize,
buf: Vec<f64>,
@ -44,21 +52,11 @@ fn main() -> Result<()> {
for _ in 0..opt.outer_iterations {
let new_w = (0..data.dimensions)
.par_bridge()
// .par_bridge()
.map(|i| {
let x_i_start = data.rows * i;
let x_i_end = data.rows * (i + 1);
let x_i = &data.buf[x_i_start..x_i_end];
let missing_i = i;
let data_ni: Span<f64, _, (usize, usize)> =
Span::new(&data.buf, |(j2, i2): (usize, usize)| {
data.buf
[data.rows * (if i2 >= missing_i { i2 + 1 } else { i2 }) + j2]
});
let w_ni: Span<f64, _, usize> = Span::new(&w, |i2: usize| {
w[if i2 >= missing_i { i2 + 1 } else { i2 }]
});
// let x_i_start = data.rows * i;
// let x_i_end = data.rows * (i + 1);
// let x_i = &data.buf[x_i_start..x_i_end];
// X = n x m
// y = n x 1
@ -69,13 +67,15 @@ fn main() -> Result<()> {
// X_ni_w_ni = n x 1
let x_ni_w_ni = (0..data.rows)
.par_bridge()
// .par_bridge()
.map(|j| {
(0..data.dimensions)
.filter(|i| *i != missing_i)
.filter(|i2| *i2 != i)
.par_bridge()
.map(|i| data_ni[(j, i)] * w_ni[i])
.reduce(|| 0.0, |a, b| a + b)
.map(|i| data.get(i, j) * w[i])
.sum::<f64>()
// label.buf[j] - result
})
.collect::<Vec<_>>();
@ -86,18 +86,29 @@ fn main() -> Result<()> {
.map(|(y, x)| y - x)
.collect::<Vec<_>>();
// combine
let numer = x_i
.par_iter()
.zip(sub.par_iter())
.map(|(x, y)| x * y)
.reduce(|| 0.0, |x, y| x + y);
let (numer, denom) = (0..data.rows)
.par_bridge()
.map(|j| {
let xij = data.get(i, j);
let num = xij * sub[j];
let den = xij * xij;
(num, den)
})
.reduce(|| (0.0, 0.0), |(n1, d1), (n2, d2)| (n1 + d1, n2 + d2));
// .unwrap();
let denom = x_i
.par_iter()
.zip(x_i.par_iter())
.map(|(a, b)| a * b)
.reduce(|| 0.0, |x, y| x + y);
// combine
// let numer = x_i
// .par_iter()
// .zip(sub.par_iter())
// .map(|(x, y)| x * y)
// .reduce(|| 0.0, |x, y| x + y);
// let denom = x_i
// .par_iter()
// .zip(x_i.par_iter())
// .map(|(a, b)| a * b)
// .reduce(|| 0.0, |x, y| x + y);
return numer / denom;
})
@ -107,7 +118,6 @@ fn main() -> Result<()> {
println!("w is {w:?}");
}
println!("Hello, world!");
Ok(())
}