use sqlite instead of bolt

This commit is contained in:
Michael Zhang 2021-07-21 17:09:06 -05:00
parent a19ae60600
commit ea7db233aa
Signed by: michael
GPG Key ID: BDA47A31A3C8EE6B
10 changed files with 139 additions and 210 deletions

1
.gitignore vendored
View File

@ -1,4 +1,5 @@
/subscribe-bot
/test.db
/test.db-sqlite
/config.toml
/repos

12
db/beatmap.go Normal file
View File

@ -0,0 +1,12 @@
package db
import "gorm.io/gorm"
type Beatmapset struct {
gorm.Model
ID int `gorm:"primaryKey"`
Artist string
Title string
MapperID int
Mapper User `gorm:"foreignKey:MapperID;references:ID"`
}

9
db/channel.go Normal file
View File

@ -0,0 +1,9 @@
package db
import "gorm.io/gorm"
type DiscordChannel struct {
gorm.Model
ID string `gorm:"primaryKey"`
TrackedMappers []User `gorm:"many2many:tracked_mappers"`
}

9
db/config.go Normal file
View File

@ -0,0 +1,9 @@
package db
import "gorm.io/gorm"
type Config struct {
gorm.Model
Key string `gorm:"key;primaryKey"`
Value string `gorm:"value"`
}

266
db/db.go
View File

@ -8,7 +8,11 @@ package db
import (
"strconv"
"github.com/pkg/errors"
bolt "go.etcd.io/bbolt"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
"subscribe-bot/osuapi"
)
@ -20,236 +24,118 @@ var (
)
type Db struct {
*bolt.DB
api *osuapi.Osuapi
gorm *gorm.DB
api *osuapi.Osuapi
}
func OpenDb(path string, api *osuapi.Osuapi) (db *Db, err error) {
inner, err := bolt.Open(path, 0666, nil)
db = &Db{inner, api}
gorm, err := gorm.Open(sqlite.Open(path), &gorm.Config{})
if err != nil {
panic("failed to connect database")
}
// auto-migrate
gorm.AutoMigrate(&Config{})
gorm.AutoMigrate(&User{})
gorm.AutoMigrate(&Beatmapset{})
gorm.AutoMigrate(&DiscordChannel{})
db = &Db{gorm, api}
return
}
// Loop over channels that are tracking this specific mapper
func (db *Db) IterTrackingChannels(mapperId int, fn func(channelId string) error) (err error) {
err = db.DB.View(func(tx *bolt.Tx) error {
mapper := getMapper(tx, mapperId)
if mapper == nil {
return nil
}
trackers := mapper.Bucket([]byte("trackers"))
if trackers == nil {
return nil
}
c := trackers.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
channelId := string(k)
err := fn(channelId)
if err != nil {
return err
}
}
return nil
})
func (db *Db) IterTrackingChannels(mapperId int, fn func(channel DiscordChannel) error) (err error) {
var channels []DiscordChannel
db.gorm.Model(&User{ID: mapperId}).Association("TrackingChannels").Find(&channels)
for _, channel := range channels {
fn(channel)
}
return
}
// Loop over tracked mappers for this channel
func (db *Db) IterChannelTrackedMappers(channelId string, fn func(userId int) error) (err error) {
err = db.DB.View(func(tx *bolt.Tx) error {
channels := tx.Bucket(CHANNELS)
if channels == nil {
return nil
}
channel := channels.Bucket([]byte(channelId))
if channel == nil {
return nil
}
tracks := channel.Bucket([]byte("tracks"))
if tracks == nil {
return nil
}
c := tracks.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
mapperId, err := strconv.Atoi(string(k))
if err != nil {
return err
}
err = fn(mapperId)
if err != nil {
return err
}
}
return nil
})
func (db *Db) IterChannelTrackedMappers(channelId string, fn func(user User) error) (err error) {
var mappers []User
db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Find(&mappers)
for _, mapper := range mappers {
fn(mapper)
}
return
}
// Loop over all tracked mappers
func (db *Db) IterAllTrackedMappers(fn func(userId int) error) (err error) {
err = db.DB.View(func(tx *bolt.Tx) error {
mappers := tx.Bucket(MAPPERS)
if mappers == nil {
return nil
}
c := mappers.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
userId, err := strconv.Atoi(string(k))
if err != nil {
return err
}
err = fn(userId)
if err != nil {
return err
}
}
return nil
})
func (db *Db) IterAllTrackedMappers(fn func(user User) error) (err error) {
var mappers []User
db.gorm.Find(&mappers)
for _, mapper := range mappers {
fn(mapper)
}
return
}
// Get a list of channels that are tracking this mapper
func (db *Db) GetMapperTrackers(userId int) (trackersList []string) {
trackersList = make([]string, 0)
db.DB.View(func(tx *bolt.Tx) error {
mapper, err := getMapperMut(tx, userId)
func (db *Db) UpdateMapperLastEvent(userId int, eventId int) (err error) {
if eventId == -1 {
var events []osuapi.Event
events, err = db.api.GetUserEvents(userId, 1, 0)
if err != nil {
return err
err = errors.Wrap(err, "couldn't get user events from API")
return
}
eventId = events[0].ID
}
trackers := mapper.Bucket([]byte("trackers"))
if trackers == nil {
return nil
}
c := trackers.Cursor()
for k, _ := c.First(); k != nil; k, _ = c.Next() {
channelId := string(k)
trackersList = append(trackersList, channelId)
}
return nil
})
return
}
// Update the latest event of a mapper to the given one
func (db *Db) UpdateMapperLatestEvent(userId int, eventId int) (err error) {
err = db.DB.Update(func(tx *bolt.Tx) error {
mapper, err := getMapperMut(tx, userId)
if err != nil {
return err
}
err = mapper.Put(LATEST_EVENT, []byte(strconv.Itoa(eventId)))
if err != nil {
return err
}
return nil
})
return
db.gorm.Model(&User{}).Where("id = ?", userId).Update("latest_event_id", eventId)
return nil
}
// Get the latest event ID of this mapper, if they have one
func (db *Db) MapperLastEvent(userId int) (has bool, id int) {
has = false
id = -1
db.DB.View(func(tx *bolt.Tx) error {
mapper := getMapper(tx, userId)
if mapper == nil {
return nil
}
lastEventId := mapper.Get(LATEST_EVENT)
if lastEventId == nil {
return nil
}
var err error
id, err = strconv.Atoi(string(lastEventId))
if err != nil {
return nil
}
has = true
return nil
})
return
var user User
db.gorm.Select("latest_event_id").First(&user)
return true, user.LatestEventID
}
// Start tracking a new mapper (if they're not already tracked)
func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) (err error) {
events, err := db.api.GetUserEvents(mapperId, 1, 0)
err = db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Append(db.getUser(mapperId))
if err != nil {
err = errors.Wrap(err, "could not add tracking for channel "+channelId)
return
}
err = db.Batch(func(tx *bolt.Tx) error {
{
mapper, err := getMapperMut(tx, mapperId)
if err != nil {
return err
}
err = db.UpdateMapperLastEvent(mapperId, -1)
if err != nil {
err = errors.Wrap(err, "could not update mapper latest event")
return
}
if len(events) > 0 {
latestEventId := strconv.Itoa(events[0].ID)
mapper.Put(LATEST_EVENT, []byte(latestEventId))
}
trackers, err := mapper.CreateBucketIfNotExists([]byte("trackers"))
if err != nil {
return err
}
err = trackers.Put([]byte(channelId), []byte(strconv.Itoa(priority)))
if err != nil {
return err
}
}
{
channels, err := tx.CreateBucketIfNotExists(CHANNELS)
if err != nil {
return err
}
channel, err := channels.CreateBucketIfNotExists([]byte(channelId))
if err != nil {
return err
}
tracks, err := channel.CreateBucketIfNotExists([]byte("tracks"))
if err != nil {
return err
}
err = tracks.Put([]byte(strconv.Itoa(mapperId)), []byte(strconv.Itoa(priority)))
if err != nil {
return err
}
}
return nil
})
return
return nil
}
func (db *Db) Close() {
db.DB.Close()
}
func (db *Db) getUser(userId int) (user *User, err error) {
// TODO: cache user info for some time?
apiUser, err := db.api.GetUser(strconv.Itoa(userId))
if err != nil {
err = errors.Wrap(err, "could not retrieve user from the API")
return
}
user = &User{
ID: userId,
Username: apiUser.Username,
Country: apiUser.CountryCode,
}
db.gorm.Clauses(clause.OnConflict{UpdateAll: true}).Create(user)
return
}
func getMapper(tx *bolt.Tx, userId int) (mapper *bolt.Bucket) {

12
db/user.go Normal file
View File

@ -0,0 +1,12 @@
package db
import "gorm.io/gorm"
type User struct {
gorm.Model
ID int `gorm:"primaryKey"`
Username string
Country string
LatestEventID int `gorm:"latest_event_id"`
TrackingChannels []DiscordChannel `gorm:"many2many:tracked_mappers"`
}

View File

@ -305,13 +305,7 @@ func (bot *Bot) newMessageHandler(s *discordgo.Session, m *discordgo.MessageCrea
case "list":
mappers := make([]string, 0)
bot.db.IterChannelTrackedMappers(m.ChannelID, func(userId int) error {
var mapper osuapi.User
mapper, err = bot.api.GetUser(strconv.Itoa(userId))
if err != nil {
return err
}
bot.db.IterChannelTrackedMappers(m.ChannelID, func(mapper db.User) error {
mappers = append(mappers, mapper.Username)
return nil
})

2
go.mod
View File

@ -17,4 +17,6 @@ require (
github.com/pkg/errors v0.8.1
go.etcd.io/bbolt v1.3.5
golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520
gorm.io/driver/sqlite v1.1.4
gorm.io/gorm v1.21.12
)

12
go.sum
View File

@ -82,6 +82,11 @@ github.com/imdario/mergo v0.3.9/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI=
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns=
@ -107,6 +112,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-sqlite3 v1.14.5 h1:1IdxlwTNazvbKJQSxoJ5/9ECbEeaTTyeU7sEAZ5KKTQ=
github.com/mattn/go-sqlite3 v1.14.5/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI=
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
@ -191,3 +198,8 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gorm.io/driver/sqlite v1.1.4 h1:PDzwYE+sI6De2+mxAneV9Xs11+ZyKV6oxD3wDGkaNvM=
gorm.io/driver/sqlite v1.1.4/go.mod h1:mJCeTFr7+crvS+TRnWc5Z3UvwxUN1BGBLMrf5LA9DYw=
gorm.io/gorm v1.20.7/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw=
gorm.io/gorm v1.21.12 h1:3fQM0Eiz7jcJEhPggHEpoYnsGZqynMzverL77DV40RM=
gorm.io/gorm v1.21.12/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=

View File

@ -11,8 +11,8 @@ import (
func (s *Scraper) scrapePendingMaps() {
// build a list of currently tracked mappers
trackedMappers := make(map[int]int)
s.db.IterAllTrackedMappers(func(userId int) error {
trackedMappers[userId] = 1
s.db.IterAllTrackedMappers(func(mapper db.User) error {
trackedMappers[mapper.ID] = 1
return nil
})
@ -54,8 +54,8 @@ func (s *Scraper) scrapePendingMaps() {
if len(allNewMaps) > 0 {
for mapperId, newMaps := range allNewMaps {
channels := make([]string, 0)
s.db.IterTrackingChannels(mapperId, func(channelId string) error {
channels = append(channels, channelId)
s.db.IterTrackingChannels(mapperId, func(channel db.DiscordChannel) error {
channels = append(channels, channel.ID)
return nil
})
@ -67,11 +67,6 @@ func (s *Scraper) scrapePendingMaps() {
}
lastUpdateTime = newLastUpdateTime
// this rings the terminal bell when it's updated so i don't have to stare
// at a blank screen for 30 seconds waiting for the feed to update
if s.config.Debug {
fmt.Print("\a")
}
log.Println("last updated time", lastUpdateTime)
}
@ -138,11 +133,8 @@ func getNewMaps(db *db.Db, api *osuapi.Osuapi, userId int) (newMaps []osuapi.Eve
}
}
// TODO: debug
// updateLatestEvent = false
if updateLatestEvent {
err = db.UpdateMapperLatestEvent(userId, newLatestEvent)
err = db.UpdateMapperLastEvent(userId, newLatestEvent)
if err != nil {
return
}