diff --git a/db/db.go b/db/db.go index 1b0dd60..4c07ced 100644 --- a/db/db.go +++ b/db/db.go @@ -15,6 +15,8 @@ import ( var ( LATEST_EVENT = []byte("latestEvent") + MAPPERS = []byte("mapper") + CHANNELS = []byte("channels") ) type Db struct { @@ -56,10 +58,47 @@ func (db *Db) IterTrackingChannels(mapperId int, fn func(channelId string) error return } -// Loop over tracked mappers -func (db *Db) IterTrackedMappers(fn func(userId int) error) (err error) { +// Loop over tracked mappers for this channel +func (db *Db) IterChannelTrackedMappers(channelId string, fn func(userId int) error) (err error) { err = db.DB.View(func(tx *bolt.Tx) error { - mappers := tx.Bucket([]byte("mapper")) + channels := tx.Bucket(CHANNELS) + if channels == nil { + return nil + } + + channel := channels.Bucket([]byte(channelId)) + if channel == nil { + return nil + } + + tracks := channel.Bucket([]byte("tracks")) + if tracks == nil { + return nil + } + + c := tracks.Cursor() + for k, _ := c.First(); k != nil; k, _ = c.Next() { + mapperId, err := strconv.Atoi(string(k)) + if err != nil { + return err + } + + err = fn(mapperId) + if err != nil { + return err + } + } + + return nil + }) + + return +} + +// Loop over all tracked mappers +func (db *Db) IterAllTrackedMappers(fn func(userId int) error) (err error) { + err = db.DB.View(func(tx *bolt.Tx) error { + mappers := tx.Bucket(MAPPERS) if mappers == nil { return nil } @@ -184,7 +223,7 @@ func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) ( } } { - channels, err := tx.CreateBucketIfNotExists([]byte("channels")) + channels, err := tx.CreateBucketIfNotExists(CHANNELS) if err != nil { return err } @@ -214,7 +253,7 @@ func (db *Db) Close() { } func getMapper(tx *bolt.Tx, userId int) (mapper *bolt.Bucket) { - mappers := tx.Bucket([]byte("mapper")) + mappers := tx.Bucket(MAPPERS) if mappers == nil { return nil } @@ -228,7 +267,7 @@ func getMapper(tx *bolt.Tx, userId int) (mapper *bolt.Bucket) { } func getMapperMut(tx *bolt.Tx, userId int) (mapper *bolt.Bucket, err error) { - mappers, err := tx.CreateBucketIfNotExists([]byte("mapper")) + mappers, err := tx.CreateBucketIfNotExists(MAPPERS) if err != nil { return } diff --git a/discord/bot.go b/discord/bot.go index db0acec..b1202d1 100644 --- a/discord/bot.go +++ b/discord/bot.go @@ -300,6 +300,21 @@ func (bot *Bot) newMessageHandler(s *discordgo.Session, m *discordgo.MessageCrea } bot.ChannelMessageSend(m.ChannelID, fmt.Sprintf("subscribed to %+v", mapper)) + + case "list": + mappers := make([]string, 0) + bot.db.IterChannelTrackedMappers(m.ChannelID, func(userId int) error { + var mapper osuapi.User + mapper, err = bot.api.GetUser(strconv.Itoa(userId)) + if err != nil { + return err + } + + mappers = append(mappers, mapper.Username) + return nil + }) + + bot.ChannelMessageSend(m.ChannelID, "tracking: "+strings.Join(mappers, ", ")) } return diff --git a/scrape/scrape.go b/scrape/scrape.go index ad107a4..a1cbafe 100644 --- a/scrape/scrape.go +++ b/scrape/scrape.go @@ -22,7 +22,7 @@ func RunScraper(config *config.Config, bot *discord.Bot, db *db.Db, api *osuapi. for ; true; <-Ticker.C { // build a list of currently tracked mappers trackedMappers := make(map[int]int) - db.IterTrackedMappers(func(userId int) error { + db.IterAllTrackedMappers(func(userId int) error { trackedMappers[userId] = 1 return nil })