diff --git a/.gitignore b/.gitignore index 09f1a77..c441fc8 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /subscribe-bot /test.db +/test.db-sqlite /config.toml /repos diff --git a/db/beatmap.go b/db/beatmap.go new file mode 100644 index 0000000..ede66fb --- /dev/null +++ b/db/beatmap.go @@ -0,0 +1,12 @@ +package db + +import "gorm.io/gorm" + +type Beatmapset struct { + gorm.Model + ID int `gorm:"primaryKey"` + Artist string + Title string + MapperID int + Mapper User `gorm:"foreignKey:MapperID;references:ID"` +} diff --git a/db/channel.go b/db/channel.go new file mode 100644 index 0000000..a235c89 --- /dev/null +++ b/db/channel.go @@ -0,0 +1,9 @@ +package db + +import "gorm.io/gorm" + +type DiscordChannel struct { + gorm.Model + ID string `gorm:"primaryKey"` + TrackedMappers []User `gorm:"many2many:tracked_mappers"` +} diff --git a/db/config.go b/db/config.go new file mode 100644 index 0000000..353cdae --- /dev/null +++ b/db/config.go @@ -0,0 +1,9 @@ +package db + +import "gorm.io/gorm" + +type Config struct { + gorm.Model + Key string `gorm:"key;primaryKey"` + Value string `gorm:"value"` +} diff --git a/db/db.go b/db/db.go index 4c07ced..0481fb5 100644 --- a/db/db.go +++ b/db/db.go @@ -8,7 +8,11 @@ package db import ( "strconv" + "github.com/pkg/errors" bolt "go.etcd.io/bbolt" + "gorm.io/driver/sqlite" + "gorm.io/gorm" + "gorm.io/gorm/clause" "subscribe-bot/osuapi" ) @@ -20,236 +24,118 @@ var ( ) type Db struct { - *bolt.DB - api *osuapi.Osuapi + gorm *gorm.DB + api *osuapi.Osuapi } func OpenDb(path string, api *osuapi.Osuapi) (db *Db, err error) { - inner, err := bolt.Open(path, 0666, nil) - db = &Db{inner, api} + gorm, err := gorm.Open(sqlite.Open(path), &gorm.Config{}) + if err != nil { + panic("failed to connect database") + } + + // auto-migrate + gorm.AutoMigrate(&Config{}) + gorm.AutoMigrate(&User{}) + gorm.AutoMigrate(&Beatmapset{}) + gorm.AutoMigrate(&DiscordChannel{}) + + db = &Db{gorm, api} return } // Loop over channels that are tracking this specific mapper -func (db *Db) IterTrackingChannels(mapperId int, fn func(channelId string) error) (err error) { - err = db.DB.View(func(tx *bolt.Tx) error { - mapper := getMapper(tx, mapperId) - if mapper == nil { - return nil - } - - trackers := mapper.Bucket([]byte("trackers")) - if trackers == nil { - return nil - } - - c := trackers.Cursor() - for k, _ := c.First(); k != nil; k, _ = c.Next() { - channelId := string(k) - err := fn(channelId) - if err != nil { - return err - } - } - - return nil - }) +func (db *Db) IterTrackingChannels(mapperId int, fn func(channel DiscordChannel) error) (err error) { + var channels []DiscordChannel + db.gorm.Model(&User{ID: mapperId}).Association("TrackingChannels").Find(&channels) + for _, channel := range channels { + fn(channel) + } return } // Loop over tracked mappers for this channel -func (db *Db) IterChannelTrackedMappers(channelId string, fn func(userId int) error) (err error) { - err = db.DB.View(func(tx *bolt.Tx) error { - channels := tx.Bucket(CHANNELS) - if channels == nil { - return nil - } - - channel := channels.Bucket([]byte(channelId)) - if channel == nil { - return nil - } - - tracks := channel.Bucket([]byte("tracks")) - if tracks == nil { - return nil - } - - c := tracks.Cursor() - for k, _ := c.First(); k != nil; k, _ = c.Next() { - mapperId, err := strconv.Atoi(string(k)) - if err != nil { - return err - } - - err = fn(mapperId) - if err != nil { - return err - } - } - - return nil - }) +func (db *Db) IterChannelTrackedMappers(channelId string, fn func(user User) error) (err error) { + var mappers []User + db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Find(&mappers) + for _, mapper := range mappers { + fn(mapper) + } return } // Loop over all tracked mappers -func (db *Db) IterAllTrackedMappers(fn func(userId int) error) (err error) { - err = db.DB.View(func(tx *bolt.Tx) error { - mappers := tx.Bucket(MAPPERS) - if mappers == nil { - return nil - } - - c := mappers.Cursor() - for k, _ := c.First(); k != nil; k, _ = c.Next() { - userId, err := strconv.Atoi(string(k)) - if err != nil { - return err - } - - err = fn(userId) - if err != nil { - return err - } - } - - return nil - }) +func (db *Db) IterAllTrackedMappers(fn func(user User) error) (err error) { + var mappers []User + db.gorm.Find(&mappers) + for _, mapper := range mappers { + fn(mapper) + } return } -// Get a list of channels that are tracking this mapper -func (db *Db) GetMapperTrackers(userId int) (trackersList []string) { - trackersList = make([]string, 0) - db.DB.View(func(tx *bolt.Tx) error { - mapper, err := getMapperMut(tx, userId) +func (db *Db) UpdateMapperLastEvent(userId int, eventId int) (err error) { + if eventId == -1 { + var events []osuapi.Event + events, err = db.api.GetUserEvents(userId, 1, 0) if err != nil { - return err + err = errors.Wrap(err, "couldn't get user events from API") + return } + eventId = events[0].ID + } - trackers := mapper.Bucket([]byte("trackers")) - if trackers == nil { - return nil - } - - c := trackers.Cursor() - for k, _ := c.First(); k != nil; k, _ = c.Next() { - channelId := string(k) - trackersList = append(trackersList, channelId) - } - - return nil - }) - return -} - -// Update the latest event of a mapper to the given one -func (db *Db) UpdateMapperLatestEvent(userId int, eventId int) (err error) { - err = db.DB.Update(func(tx *bolt.Tx) error { - mapper, err := getMapperMut(tx, userId) - if err != nil { - return err - } - - err = mapper.Put(LATEST_EVENT, []byte(strconv.Itoa(eventId))) - if err != nil { - return err - } - - return nil - }) - return + db.gorm.Model(&User{}).Where("id = ?", userId).Update("latest_event_id", eventId) + return nil } // Get the latest event ID of this mapper, if they have one func (db *Db) MapperLastEvent(userId int) (has bool, id int) { - has = false - id = -1 - db.DB.View(func(tx *bolt.Tx) error { - mapper := getMapper(tx, userId) - if mapper == nil { - return nil - } - - lastEventId := mapper.Get(LATEST_EVENT) - if lastEventId == nil { - return nil - } - - var err error - id, err = strconv.Atoi(string(lastEventId)) - if err != nil { - return nil - } - - has = true - return nil - }) - - return + var user User + db.gorm.Select("latest_event_id").First(&user) + return true, user.LatestEventID } // Start tracking a new mapper (if they're not already tracked) func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) (err error) { - events, err := db.api.GetUserEvents(mapperId, 1, 0) + err = db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Append(db.getUser(mapperId)) if err != nil { + err = errors.Wrap(err, "could not add tracking for channel "+channelId) return } - err = db.Batch(func(tx *bolt.Tx) error { - { - mapper, err := getMapperMut(tx, mapperId) - if err != nil { - return err - } + err = db.UpdateMapperLastEvent(mapperId, -1) + if err != nil { + err = errors.Wrap(err, "could not update mapper latest event") + return + } - if len(events) > 0 { - latestEventId := strconv.Itoa(events[0].ID) - mapper.Put(LATEST_EVENT, []byte(latestEventId)) - } - - trackers, err := mapper.CreateBucketIfNotExists([]byte("trackers")) - if err != nil { - return err - } - - err = trackers.Put([]byte(channelId), []byte(strconv.Itoa(priority))) - if err != nil { - return err - } - } - { - channels, err := tx.CreateBucketIfNotExists(CHANNELS) - if err != nil { - return err - } - - channel, err := channels.CreateBucketIfNotExists([]byte(channelId)) - if err != nil { - return err - } - - tracks, err := channel.CreateBucketIfNotExists([]byte("tracks")) - if err != nil { - return err - } - - err = tracks.Put([]byte(strconv.Itoa(mapperId)), []byte(strconv.Itoa(priority))) - if err != nil { - return err - } - } - return nil - }) - return + return nil } func (db *Db) Close() { - db.DB.Close() +} + +func (db *Db) getUser(userId int) (user *User, err error) { + // TODO: cache user info for some time? + + apiUser, err := db.api.GetUser(strconv.Itoa(userId)) + if err != nil { + err = errors.Wrap(err, "could not retrieve user from the API") + return + } + + user = &User{ + ID: userId, + Username: apiUser.Username, + Country: apiUser.CountryCode, + } + db.gorm.Clauses(clause.OnConflict{UpdateAll: true}).Create(user) + + return } func getMapper(tx *bolt.Tx, userId int) (mapper *bolt.Bucket) { diff --git a/db/user.go b/db/user.go new file mode 100644 index 0000000..ea46a73 --- /dev/null +++ b/db/user.go @@ -0,0 +1,12 @@ +package db + +import "gorm.io/gorm" + +type User struct { + gorm.Model + ID int `gorm:"primaryKey"` + Username string + Country string + LatestEventID int `gorm:"latest_event_id"` + TrackingChannels []DiscordChannel `gorm:"many2many:tracked_mappers"` +} diff --git a/discord/bot.go b/discord/bot.go index 5075e98..d8dfda2 100644 --- a/discord/bot.go +++ b/discord/bot.go @@ -305,13 +305,7 @@ func (bot *Bot) newMessageHandler(s *discordgo.Session, m *discordgo.MessageCrea case "list": mappers := make([]string, 0) - bot.db.IterChannelTrackedMappers(m.ChannelID, func(userId int) error { - var mapper osuapi.User - mapper, err = bot.api.GetUser(strconv.Itoa(userId)) - if err != nil { - return err - } - + bot.db.IterChannelTrackedMappers(m.ChannelID, func(mapper db.User) error { mappers = append(mappers, mapper.Username) return nil }) diff --git a/go.mod b/go.mod index 36b0429..017c385 100644 --- a/go.mod +++ b/go.mod @@ -17,4 +17,6 @@ require ( github.com/pkg/errors v0.8.1 go.etcd.io/bbolt v1.3.5 golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520 + gorm.io/driver/sqlite v1.1.4 + gorm.io/gorm v1.21.12 ) diff --git a/go.sum b/go.sum index 0d4de69..fb0c528 100644 --- a/go.sum +++ b/go.sum @@ -82,6 +82,11 @@ github.com/imdario/mergo v0.3.9/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A= github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo= github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= +github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI= +github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU= github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4= github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns= @@ -107,6 +112,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ= github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY= github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-sqlite3 v1.14.5 h1:1IdxlwTNazvbKJQSxoJ5/9ECbEeaTTyeU7sEAZ5KKTQ= +github.com/mattn/go-sqlite3 v1.14.5/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI= github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y= github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= @@ -191,3 +198,8 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gorm.io/driver/sqlite v1.1.4 h1:PDzwYE+sI6De2+mxAneV9Xs11+ZyKV6oxD3wDGkaNvM= +gorm.io/driver/sqlite v1.1.4/go.mod h1:mJCeTFr7+crvS+TRnWc5Z3UvwxUN1BGBLMrf5LA9DYw= +gorm.io/gorm v1.20.7/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw= +gorm.io/gorm v1.21.12 h1:3fQM0Eiz7jcJEhPggHEpoYnsGZqynMzverL77DV40RM= +gorm.io/gorm v1.21.12/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0= diff --git a/scrape/pending.go b/scrape/pending.go index 3ff895a..7b4bc36 100644 --- a/scrape/pending.go +++ b/scrape/pending.go @@ -11,8 +11,8 @@ import ( func (s *Scraper) scrapePendingMaps() { // build a list of currently tracked mappers trackedMappers := make(map[int]int) - s.db.IterAllTrackedMappers(func(userId int) error { - trackedMappers[userId] = 1 + s.db.IterAllTrackedMappers(func(mapper db.User) error { + trackedMappers[mapper.ID] = 1 return nil }) @@ -54,8 +54,8 @@ func (s *Scraper) scrapePendingMaps() { if len(allNewMaps) > 0 { for mapperId, newMaps := range allNewMaps { channels := make([]string, 0) - s.db.IterTrackingChannels(mapperId, func(channelId string) error { - channels = append(channels, channelId) + s.db.IterTrackingChannels(mapperId, func(channel db.DiscordChannel) error { + channels = append(channels, channel.ID) return nil }) @@ -67,11 +67,6 @@ func (s *Scraper) scrapePendingMaps() { } lastUpdateTime = newLastUpdateTime - // this rings the terminal bell when it's updated so i don't have to stare - // at a blank screen for 30 seconds waiting for the feed to update - if s.config.Debug { - fmt.Print("\a") - } log.Println("last updated time", lastUpdateTime) } @@ -138,11 +133,8 @@ func getNewMaps(db *db.Db, api *osuapi.Osuapi, userId int) (newMaps []osuapi.Eve } } - // TODO: debug - // updateLatestEvent = false - if updateLatestEvent { - err = db.UpdateMapperLatestEvent(userId, newLatestEvent) + err = db.UpdateMapperLastEvent(userId, newLatestEvent) if err != nil { return }