This commit is contained in:
Michael Zhang 2021-09-06 16:03:42 -05:00
parent ac61899f6e
commit 609785d503
Signed by: michael
GPG key ID: BDA47A31A3C8EE6B
3 changed files with 49 additions and 6 deletions

View file

@ -1 +1,4 @@
.gradle .gradle
.project
.settings
build

View file

@ -2,6 +2,7 @@ package org.lenskit.mooc.nonpers.mean;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.lenskit.baseline.MeanDamping; import org.lenskit.baseline.MeanDamping;
import org.lenskit.data.dao.DataAccessObject; import org.lenskit.data.dao.DataAccessObject;
import org.lenskit.data.ratings.Rating; import org.lenskit.data.ratings.Rating;
@ -62,7 +63,35 @@ public class DampedItemMeanModelProvider implements Provider<ItemMeanModel> {
@Override @Override
public ItemMeanModel get() { public ItemMeanModel get() {
// TODO Compute damped means // TODO Compute damped means
// TODO Remove the line below when you have finished Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap();
throw new UnsupportedOperationException("damped mean not implemented"); Long2IntOpenHashMap lens = new Long2IntOpenHashMap();
double globalMean = 0;
int globalLen = 0;
try (ObjectStream<Rating> ratings = dao.query(Rating.class).stream()) {
for (Rating r: ratings) {
// this loop will run once for each rating in the data set
means.addTo(r.getItemId(), r.getValue());
lens.addTo(r.getItemId(), 1);
globalMean += r.getValue();
globalLen += 1;
}
}
globalMean /= globalLen;
LongSet keys = means.keySet();
for (long key : keys) {
double val = means.get(key);
val = (val + damping * globalMean) / (lens.get(key) + damping);
means.put(key, val);
if (key == 2959 || key == 1203) {
logger.info("Damped mean for item {} is {}", key, val);
}
}
return new ItemMeanModel(means);
} }
} }

View file

@ -2,6 +2,7 @@ package org.lenskit.mooc.nonpers.mean;
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap; import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap; import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
import it.unimi.dsi.fastutil.longs.LongSet;
import org.lenskit.data.dao.DataAccessObject; import org.lenskit.data.dao.DataAccessObject;
import org.lenskit.data.ratings.Rating; import org.lenskit.data.ratings.Rating;
import org.lenskit.inject.Transient; import org.lenskit.inject.Transient;
@ -51,17 +52,27 @@ public class ItemMeanModelProvider implements Provider<ItemMeanModel> {
*/ */
@Override @Override
public ItemMeanModel get() { public ItemMeanModel get() {
// TODO Set up data structures for computing means Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap();
Long2IntOpenHashMap lens = new Long2IntOpenHashMap();
try (ObjectStream<Rating> ratings = dao.query(Rating.class).stream()) { try (ObjectStream<Rating> ratings = dao.query(Rating.class).stream()) {
for (Rating r: ratings) { for (Rating r: ratings) {
// this loop will run once for each rating in the data set // this loop will run once for each rating in the data set
// TODO process this rating means.addTo(r.getItemId(), r.getValue());
lens.addTo(r.getItemId(), 1);
} }
} }
Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap(); LongSet keys = means.keySet();
// TODO Finalize means to store them in the mean model for (long key : keys) {
double val = means.get(key);
val /= lens.get(key);
means.put(key, val);
if (key == 2959 || key == 1203) {
logger.info("Damped mean for item {} is {}", key, val);
}
}
logger.info("computed mean ratings for {} items", means.size()); logger.info("computed mean ratings for {} items", means.size());
return new ItemMeanModel(means); return new ItemMeanModel(means);