use sqlite instead of bolt
This commit is contained in:
parent
a19ae60600
commit
ea7db233aa
10 changed files with 139 additions and 210 deletions
1
.gitignore
vendored
1
.gitignore
vendored
|
@ -1,4 +1,5 @@
|
|||
/subscribe-bot
|
||||
/test.db
|
||||
/test.db-sqlite
|
||||
/config.toml
|
||||
/repos
|
||||
|
|
12
db/beatmap.go
Normal file
12
db/beatmap.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package db
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type Beatmapset struct {
|
||||
gorm.Model
|
||||
ID int `gorm:"primaryKey"`
|
||||
Artist string
|
||||
Title string
|
||||
MapperID int
|
||||
Mapper User `gorm:"foreignKey:MapperID;references:ID"`
|
||||
}
|
9
db/channel.go
Normal file
9
db/channel.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package db
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type DiscordChannel struct {
|
||||
gorm.Model
|
||||
ID string `gorm:"primaryKey"`
|
||||
TrackedMappers []User `gorm:"many2many:tracked_mappers"`
|
||||
}
|
9
db/config.go
Normal file
9
db/config.go
Normal file
|
@ -0,0 +1,9 @@
|
|||
package db
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type Config struct {
|
||||
gorm.Model
|
||||
Key string `gorm:"key;primaryKey"`
|
||||
Value string `gorm:"value"`
|
||||
}
|
266
db/db.go
266
db/db.go
|
@ -8,7 +8,11 @@ package db
|
|||
import (
|
||||
"strconv"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
"gorm.io/driver/sqlite"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/clause"
|
||||
|
||||
"subscribe-bot/osuapi"
|
||||
)
|
||||
|
@ -20,236 +24,118 @@ var (
|
|||
)
|
||||
|
||||
type Db struct {
|
||||
*bolt.DB
|
||||
api *osuapi.Osuapi
|
||||
gorm *gorm.DB
|
||||
api *osuapi.Osuapi
|
||||
}
|
||||
|
||||
func OpenDb(path string, api *osuapi.Osuapi) (db *Db, err error) {
|
||||
inner, err := bolt.Open(path, 0666, nil)
|
||||
db = &Db{inner, api}
|
||||
gorm, err := gorm.Open(sqlite.Open(path), &gorm.Config{})
|
||||
if err != nil {
|
||||
panic("failed to connect database")
|
||||
}
|
||||
|
||||
// auto-migrate
|
||||
gorm.AutoMigrate(&Config{})
|
||||
gorm.AutoMigrate(&User{})
|
||||
gorm.AutoMigrate(&Beatmapset{})
|
||||
gorm.AutoMigrate(&DiscordChannel{})
|
||||
|
||||
db = &Db{gorm, api}
|
||||
return
|
||||
}
|
||||
|
||||
// Loop over channels that are tracking this specific mapper
|
||||
func (db *Db) IterTrackingChannels(mapperId int, fn func(channelId string) error) (err error) {
|
||||
err = db.DB.View(func(tx *bolt.Tx) error {
|
||||
mapper := getMapper(tx, mapperId)
|
||||
if mapper == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
trackers := mapper.Bucket([]byte("trackers"))
|
||||
if trackers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := trackers.Cursor()
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
channelId := string(k)
|
||||
err := fn(channelId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
func (db *Db) IterTrackingChannels(mapperId int, fn func(channel DiscordChannel) error) (err error) {
|
||||
var channels []DiscordChannel
|
||||
db.gorm.Model(&User{ID: mapperId}).Association("TrackingChannels").Find(&channels)
|
||||
for _, channel := range channels {
|
||||
fn(channel)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Loop over tracked mappers for this channel
|
||||
func (db *Db) IterChannelTrackedMappers(channelId string, fn func(userId int) error) (err error) {
|
||||
err = db.DB.View(func(tx *bolt.Tx) error {
|
||||
channels := tx.Bucket(CHANNELS)
|
||||
if channels == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
channel := channels.Bucket([]byte(channelId))
|
||||
if channel == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
tracks := channel.Bucket([]byte("tracks"))
|
||||
if tracks == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := tracks.Cursor()
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
mapperId, err := strconv.Atoi(string(k))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fn(mapperId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
func (db *Db) IterChannelTrackedMappers(channelId string, fn func(user User) error) (err error) {
|
||||
var mappers []User
|
||||
db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Find(&mappers)
|
||||
for _, mapper := range mappers {
|
||||
fn(mapper)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Loop over all tracked mappers
|
||||
func (db *Db) IterAllTrackedMappers(fn func(userId int) error) (err error) {
|
||||
err = db.DB.View(func(tx *bolt.Tx) error {
|
||||
mappers := tx.Bucket(MAPPERS)
|
||||
if mappers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := mappers.Cursor()
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
userId, err := strconv.Atoi(string(k))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = fn(userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
func (db *Db) IterAllTrackedMappers(fn func(user User) error) (err error) {
|
||||
var mappers []User
|
||||
db.gorm.Find(&mappers)
|
||||
for _, mapper := range mappers {
|
||||
fn(mapper)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Get a list of channels that are tracking this mapper
|
||||
func (db *Db) GetMapperTrackers(userId int) (trackersList []string) {
|
||||
trackersList = make([]string, 0)
|
||||
db.DB.View(func(tx *bolt.Tx) error {
|
||||
mapper, err := getMapperMut(tx, userId)
|
||||
func (db *Db) UpdateMapperLastEvent(userId int, eventId int) (err error) {
|
||||
if eventId == -1 {
|
||||
var events []osuapi.Event
|
||||
events, err = db.api.GetUserEvents(userId, 1, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
err = errors.Wrap(err, "couldn't get user events from API")
|
||||
return
|
||||
}
|
||||
eventId = events[0].ID
|
||||
}
|
||||
|
||||
trackers := mapper.Bucket([]byte("trackers"))
|
||||
if trackers == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := trackers.Cursor()
|
||||
for k, _ := c.First(); k != nil; k, _ = c.Next() {
|
||||
channelId := string(k)
|
||||
trackersList = append(trackersList, channelId)
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
// Update the latest event of a mapper to the given one
|
||||
func (db *Db) UpdateMapperLatestEvent(userId int, eventId int) (err error) {
|
||||
err = db.DB.Update(func(tx *bolt.Tx) error {
|
||||
mapper, err := getMapperMut(tx, userId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = mapper.Put(LATEST_EVENT, []byte(strconv.Itoa(eventId)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
return
|
||||
db.gorm.Model(&User{}).Where("id = ?", userId).Update("latest_event_id", eventId)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Get the latest event ID of this mapper, if they have one
|
||||
func (db *Db) MapperLastEvent(userId int) (has bool, id int) {
|
||||
has = false
|
||||
id = -1
|
||||
db.DB.View(func(tx *bolt.Tx) error {
|
||||
mapper := getMapper(tx, userId)
|
||||
if mapper == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
lastEventId := mapper.Get(LATEST_EVENT)
|
||||
if lastEventId == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var err error
|
||||
id, err = strconv.Atoi(string(lastEventId))
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
has = true
|
||||
return nil
|
||||
})
|
||||
|
||||
return
|
||||
var user User
|
||||
db.gorm.Select("latest_event_id").First(&user)
|
||||
return true, user.LatestEventID
|
||||
}
|
||||
|
||||
// Start tracking a new mapper (if they're not already tracked)
|
||||
func (db *Db) ChannelTrackMapper(channelId string, mapperId int, priority int) (err error) {
|
||||
events, err := db.api.GetUserEvents(mapperId, 1, 0)
|
||||
err = db.gorm.Model(&DiscordChannel{ID: channelId}).Association("TrackedMappers").Append(db.getUser(mapperId))
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "could not add tracking for channel "+channelId)
|
||||
return
|
||||
}
|
||||
|
||||
err = db.Batch(func(tx *bolt.Tx) error {
|
||||
{
|
||||
mapper, err := getMapperMut(tx, mapperId)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = db.UpdateMapperLastEvent(mapperId, -1)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "could not update mapper latest event")
|
||||
return
|
||||
}
|
||||
|
||||
if len(events) > 0 {
|
||||
latestEventId := strconv.Itoa(events[0].ID)
|
||||
mapper.Put(LATEST_EVENT, []byte(latestEventId))
|
||||
}
|
||||
|
||||
trackers, err := mapper.CreateBucketIfNotExists([]byte("trackers"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = trackers.Put([]byte(channelId), []byte(strconv.Itoa(priority)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
{
|
||||
channels, err := tx.CreateBucketIfNotExists(CHANNELS)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
channel, err := channels.CreateBucketIfNotExists([]byte(channelId))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tracks, err := channel.CreateBucketIfNotExists([]byte("tracks"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = tracks.Put([]byte(strconv.Itoa(mapperId)), []byte(strconv.Itoa(priority)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Db) Close() {
|
||||
db.DB.Close()
|
||||
}
|
||||
|
||||
func (db *Db) getUser(userId int) (user *User, err error) {
|
||||
// TODO: cache user info for some time?
|
||||
|
||||
apiUser, err := db.api.GetUser(strconv.Itoa(userId))
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "could not retrieve user from the API")
|
||||
return
|
||||
}
|
||||
|
||||
user = &User{
|
||||
ID: userId,
|
||||
Username: apiUser.Username,
|
||||
Country: apiUser.CountryCode,
|
||||
}
|
||||
db.gorm.Clauses(clause.OnConflict{UpdateAll: true}).Create(user)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getMapper(tx *bolt.Tx, userId int) (mapper *bolt.Bucket) {
|
||||
|
|
12
db/user.go
Normal file
12
db/user.go
Normal file
|
@ -0,0 +1,12 @@
|
|||
package db
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type User struct {
|
||||
gorm.Model
|
||||
ID int `gorm:"primaryKey"`
|
||||
Username string
|
||||
Country string
|
||||
LatestEventID int `gorm:"latest_event_id"`
|
||||
TrackingChannels []DiscordChannel `gorm:"many2many:tracked_mappers"`
|
||||
}
|
|
@ -305,13 +305,7 @@ func (bot *Bot) newMessageHandler(s *discordgo.Session, m *discordgo.MessageCrea
|
|||
|
||||
case "list":
|
||||
mappers := make([]string, 0)
|
||||
bot.db.IterChannelTrackedMappers(m.ChannelID, func(userId int) error {
|
||||
var mapper osuapi.User
|
||||
mapper, err = bot.api.GetUser(strconv.Itoa(userId))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
bot.db.IterChannelTrackedMappers(m.ChannelID, func(mapper db.User) error {
|
||||
mappers = append(mappers, mapper.Username)
|
||||
return nil
|
||||
})
|
||||
|
|
2
go.mod
2
go.mod
|
@ -17,4 +17,6 @@ require (
|
|||
github.com/pkg/errors v0.8.1
|
||||
go.etcd.io/bbolt v1.3.5
|
||||
golang.org/x/sync v0.0.0-20201008141435-b3e1573b7520
|
||||
gorm.io/driver/sqlite v1.1.4
|
||||
gorm.io/gorm v1.21.12
|
||||
)
|
||||
|
|
12
go.sum
12
go.sum
|
@ -82,6 +82,11 @@ github.com/imdario/mergo v0.3.9/go.mod h1:2EnlNZ0deacrJVfApfmtdGgDfMuh/nq6Ok1EcJ
|
|||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 h1:BQSFePA1RWJOlocH6Fxy8MmwDt+yVQYULKfN0RoTN8A=
|
||||
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99/go.mod h1:1lJo3i6rXxKeerYnT8Nvf0QmHCRC1n8sfWVwXF2Frvo=
|
||||
github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJSGM75WSRxI=
|
||||
github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
|
||||
github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
|
||||
github.com/jinzhu/now v1.1.1/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/jinzhu/now v1.1.2 h1:eVKgfIdy9b6zbWBMgFpfDPoAMifwSZagU9HmEU6zgiI=
|
||||
github.com/jinzhu/now v1.1.2/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
|
||||
github.com/json-iterator/go v1.1.6/go.mod h1:+SdeFBvtyEkXs7REEP0seUULqWtbJapLOCVDaaPEHmU=
|
||||
github.com/json-iterator/go v1.1.7/go.mod h1:KdQUCv79m/52Kvf8AW2vK1V8akMuk1QjK/uOdHXbAo4=
|
||||
github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns=
|
||||
|
@ -107,6 +112,8 @@ github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hd
|
|||
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
|
||||
github.com/mattn/go-isatty v0.0.12 h1:wuysRhFDzyxgEmMf5xjvJ2M9dZoWAXNNr5LSBS7uHXY=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-sqlite3 v1.14.5 h1:1IdxlwTNazvbKJQSxoJ5/9ECbEeaTTyeU7sEAZ5KKTQ=
|
||||
github.com/mattn/go-sqlite3 v1.14.5/go.mod h1:WVKg1VTActs4Qso6iwGbiFih2UIHo0ENGwNd0Lj+XmI=
|
||||
github.com/mitchellh/go-homedir v1.1.0 h1:lukF9ziXFxDFPkA1vsr5zpc1XuPDn/wFntq5mG+4E0Y=
|
||||
github.com/mitchellh/go-homedir v1.1.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0=
|
||||
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
|
||||
|
@ -191,3 +198,8 @@ gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
|||
gopkg.in/yaml.v2 v2.2.4/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
|
||||
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gorm.io/driver/sqlite v1.1.4 h1:PDzwYE+sI6De2+mxAneV9Xs11+ZyKV6oxD3wDGkaNvM=
|
||||
gorm.io/driver/sqlite v1.1.4/go.mod h1:mJCeTFr7+crvS+TRnWc5Z3UvwxUN1BGBLMrf5LA9DYw=
|
||||
gorm.io/gorm v1.20.7/go.mod h1:0HFTzE/SqkGTzK6TlDPPQbAYCluiVvhzoA1+aVyzenw=
|
||||
gorm.io/gorm v1.21.12 h1:3fQM0Eiz7jcJEhPggHEpoYnsGZqynMzverL77DV40RM=
|
||||
gorm.io/gorm v1.21.12/go.mod h1:F+OptMscr0P2F2qU97WT1WimdH9GaQPoDW7AYd5i2Y0=
|
||||
|
|
|
@ -11,8 +11,8 @@ import (
|
|||
func (s *Scraper) scrapePendingMaps() {
|
||||
// build a list of currently tracked mappers
|
||||
trackedMappers := make(map[int]int)
|
||||
s.db.IterAllTrackedMappers(func(userId int) error {
|
||||
trackedMappers[userId] = 1
|
||||
s.db.IterAllTrackedMappers(func(mapper db.User) error {
|
||||
trackedMappers[mapper.ID] = 1
|
||||
return nil
|
||||
})
|
||||
|
||||
|
@ -54,8 +54,8 @@ func (s *Scraper) scrapePendingMaps() {
|
|||
if len(allNewMaps) > 0 {
|
||||
for mapperId, newMaps := range allNewMaps {
|
||||
channels := make([]string, 0)
|
||||
s.db.IterTrackingChannels(mapperId, func(channelId string) error {
|
||||
channels = append(channels, channelId)
|
||||
s.db.IterTrackingChannels(mapperId, func(channel db.DiscordChannel) error {
|
||||
channels = append(channels, channel.ID)
|
||||
return nil
|
||||
})
|
||||
|
||||
|
@ -67,11 +67,6 @@ func (s *Scraper) scrapePendingMaps() {
|
|||
}
|
||||
|
||||
lastUpdateTime = newLastUpdateTime
|
||||
// this rings the terminal bell when it's updated so i don't have to stare
|
||||
// at a blank screen for 30 seconds waiting for the feed to update
|
||||
if s.config.Debug {
|
||||
fmt.Print("\a")
|
||||
}
|
||||
log.Println("last updated time", lastUpdateTime)
|
||||
}
|
||||
|
||||
|
@ -138,11 +133,8 @@ func getNewMaps(db *db.Db, api *osuapi.Osuapi, userId int) (newMaps []osuapi.Eve
|
|||
}
|
||||
}
|
||||
|
||||
// TODO: debug
|
||||
// updateLatestEvent = false
|
||||
|
||||
if updateLatestEvent {
|
||||
err = db.UpdateMapperLatestEvent(userId, newLatestEvent)
|
||||
err = db.UpdateMapperLastEvent(userId, newLatestEvent)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in a new issue