From 609785d50341911dd4d6ac48cc5e1bf09180eb45 Mon Sep 17 00:00:00 2001 From: Michael Zhang Date: Mon, 6 Sep 2021 16:03:42 -0500 Subject: [PATCH] means --- nonpers-assignment/.gitignore | 3 ++ .../mean/DampedItemMeanModelProvider.java | 33 +++++++++++++++++-- .../nonpers/mean/ItemMeanModelProvider.java | 19 ++++++++--- 3 files changed, 49 insertions(+), 6 deletions(-) diff --git a/nonpers-assignment/.gitignore b/nonpers-assignment/.gitignore index 08a55c0..2b4c1b4 100644 --- a/nonpers-assignment/.gitignore +++ b/nonpers-assignment/.gitignore @@ -1 +1,4 @@ .gradle +.project +.settings +build diff --git a/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/DampedItemMeanModelProvider.java b/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/DampedItemMeanModelProvider.java index afe0463..b489c5f 100644 --- a/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/DampedItemMeanModelProvider.java +++ b/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/DampedItemMeanModelProvider.java @@ -2,6 +2,7 @@ package org.lenskit.mooc.nonpers.mean; import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; +import it.unimi.dsi.fastutil.longs.LongSet; import org.lenskit.baseline.MeanDamping; import org.lenskit.data.dao.DataAccessObject; import org.lenskit.data.ratings.Rating; @@ -62,7 +63,35 @@ public class DampedItemMeanModelProvider implements Provider { @Override public ItemMeanModel get() { // TODO Compute damped means - // TODO Remove the line below when you have finished - throw new UnsupportedOperationException("damped mean not implemented"); + Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap(); + Long2IntOpenHashMap lens = new Long2IntOpenHashMap(); + double globalMean = 0; + int globalLen = 0; + + try (ObjectStream ratings = dao.query(Rating.class).stream()) { + for (Rating r: ratings) { + // this loop will run once for each rating in the data set + means.addTo(r.getItemId(), r.getValue()); + lens.addTo(r.getItemId(), 1); + + globalMean += r.getValue(); + globalLen += 1; + } + } + + globalMean /= globalLen; + + LongSet keys = means.keySet(); + for (long key : keys) { + double val = means.get(key); + val = (val + damping * globalMean) / (lens.get(key) + damping); + means.put(key, val); + + if (key == 2959 || key == 1203) { + logger.info("Damped mean for item {} is {}", key, val); + } + } + + return new ItemMeanModel(means); } } diff --git a/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/ItemMeanModelProvider.java b/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/ItemMeanModelProvider.java index 62935b3..a4656c5 100644 --- a/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/ItemMeanModelProvider.java +++ b/nonpers-assignment/src/main/java/org/lenskit/mooc/nonpers/mean/ItemMeanModelProvider.java @@ -2,6 +2,7 @@ package org.lenskit.mooc.nonpers.mean; import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; +import it.unimi.dsi.fastutil.longs.LongSet; import org.lenskit.data.dao.DataAccessObject; import org.lenskit.data.ratings.Rating; import org.lenskit.inject.Transient; @@ -51,17 +52,27 @@ public class ItemMeanModelProvider implements Provider { */ @Override public ItemMeanModel get() { - // TODO Set up data structures for computing means + Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap(); + Long2IntOpenHashMap lens = new Long2IntOpenHashMap(); try (ObjectStream ratings = dao.query(Rating.class).stream()) { for (Rating r: ratings) { // this loop will run once for each rating in the data set - // TODO process this rating + means.addTo(r.getItemId(), r.getValue()); + lens.addTo(r.getItemId(), 1); } } - Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap(); - // TODO Finalize means to store them in the mean model + LongSet keys = means.keySet(); + for (long key : keys) { + double val = means.get(key); + val /= lens.get(key); + means.put(key, val); + + if (key == 2959 || key == 1203) { + logger.info("Damped mean for item {} is {}", key, val); + } + } logger.info("computed mean ratings for {} items", means.size()); return new ItemMeanModel(means);