means
This commit is contained in:
parent
ac61899f6e
commit
609785d503
3 changed files with 49 additions and 6 deletions
3
nonpers-assignment/.gitignore
vendored
3
nonpers-assignment/.gitignore
vendored
|
@ -1 +1,4 @@
|
|||
.gradle
|
||||
.project
|
||||
.settings
|
||||
build
|
||||
|
|
|
@ -2,6 +2,7 @@ package org.lenskit.mooc.nonpers.mean;
|
|||
|
||||
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
|
||||
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
|
||||
import it.unimi.dsi.fastutil.longs.LongSet;
|
||||
import org.lenskit.baseline.MeanDamping;
|
||||
import org.lenskit.data.dao.DataAccessObject;
|
||||
import org.lenskit.data.ratings.Rating;
|
||||
|
@ -62,7 +63,35 @@ public class DampedItemMeanModelProvider implements Provider<ItemMeanModel> {
|
|||
@Override
|
||||
public ItemMeanModel get() {
|
||||
// TODO Compute damped means
|
||||
// TODO Remove the line below when you have finished
|
||||
throw new UnsupportedOperationException("damped mean not implemented");
|
||||
Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap();
|
||||
Long2IntOpenHashMap lens = new Long2IntOpenHashMap();
|
||||
double globalMean = 0;
|
||||
int globalLen = 0;
|
||||
|
||||
try (ObjectStream<Rating> ratings = dao.query(Rating.class).stream()) {
|
||||
for (Rating r: ratings) {
|
||||
// this loop will run once for each rating in the data set
|
||||
means.addTo(r.getItemId(), r.getValue());
|
||||
lens.addTo(r.getItemId(), 1);
|
||||
|
||||
globalMean += r.getValue();
|
||||
globalLen += 1;
|
||||
}
|
||||
}
|
||||
|
||||
globalMean /= globalLen;
|
||||
|
||||
LongSet keys = means.keySet();
|
||||
for (long key : keys) {
|
||||
double val = means.get(key);
|
||||
val = (val + damping * globalMean) / (lens.get(key) + damping);
|
||||
means.put(key, val);
|
||||
|
||||
if (key == 2959 || key == 1203) {
|
||||
logger.info("Damped mean for item {} is {}", key, val);
|
||||
}
|
||||
}
|
||||
|
||||
return new ItemMeanModel(means);
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,6 +2,7 @@ package org.lenskit.mooc.nonpers.mean;
|
|||
|
||||
import it.unimi.dsi.fastutil.longs.Long2DoubleOpenHashMap;
|
||||
import it.unimi.dsi.fastutil.longs.Long2IntOpenHashMap;
|
||||
import it.unimi.dsi.fastutil.longs.LongSet;
|
||||
import org.lenskit.data.dao.DataAccessObject;
|
||||
import org.lenskit.data.ratings.Rating;
|
||||
import org.lenskit.inject.Transient;
|
||||
|
@ -51,17 +52,27 @@ public class ItemMeanModelProvider implements Provider<ItemMeanModel> {
|
|||
*/
|
||||
@Override
|
||||
public ItemMeanModel get() {
|
||||
// TODO Set up data structures for computing means
|
||||
Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap();
|
||||
Long2IntOpenHashMap lens = new Long2IntOpenHashMap();
|
||||
|
||||
try (ObjectStream<Rating> ratings = dao.query(Rating.class).stream()) {
|
||||
for (Rating r: ratings) {
|
||||
// this loop will run once for each rating in the data set
|
||||
// TODO process this rating
|
||||
means.addTo(r.getItemId(), r.getValue());
|
||||
lens.addTo(r.getItemId(), 1);
|
||||
}
|
||||
}
|
||||
|
||||
Long2DoubleOpenHashMap means = new Long2DoubleOpenHashMap();
|
||||
// TODO Finalize means to store them in the mean model
|
||||
LongSet keys = means.keySet();
|
||||
for (long key : keys) {
|
||||
double val = means.get(key);
|
||||
val /= lens.get(key);
|
||||
means.put(key, val);
|
||||
|
||||
if (key == 2959 || key == 1203) {
|
||||
logger.info("Damped mean for item {} is {}", key, val);
|
||||
}
|
||||
}
|
||||
|
||||
logger.info("computed mean ratings for {} items", means.size());
|
||||
return new ItemMeanModel(means);
|
||||
|
|
Loading…
Reference in a new issue