refactor: use io.Reader instead of custom method

This commit is contained in:
Unlock Music Dev
2022-11-19 07:25:43 +08:00
parent 4365628bff
commit 67ff0c44cd
17 changed files with 420 additions and 460 deletions

View File

@@ -4,6 +4,7 @@ import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"strconv"
"strings"
@@ -12,14 +13,14 @@ import (
)
type Decoder struct {
r io.ReadSeeker
fileExt string
raw io.ReadSeeker
audio io.Reader
offset int
audioLen int
audioLen int
decodedKey []byte
cipher streamCipher
offset int
cipher common.StreamDecoder
decodedKey []byte
rawMetaExtra1 int
rawMetaExtra2 int
}
@@ -27,78 +28,79 @@ type Decoder struct {
// Read implements io.Reader, offer the decrypted audio data.
// Validate should call before Read to check if the file is valid.
func (d *Decoder) Read(p []byte) (int, error) {
n := len(p)
if d.audioLen <= d.offset {
return 0, io.EOF
} else if d.audioLen-d.offset < n {
n = d.audioLen - d.offset
n, err := d.audio.Read(p)
if n > 0 {
d.cipher.Decrypt(p[:n], d.offset)
d.offset += n
}
m, err := d.r.Read(p[:n])
if m > 0 {
d.cipher.Decrypt(p[:m], d.offset)
d.offset += m
}
return m, err
return n, err
}
func NewDecoder(r io.ReadSeeker) (*Decoder, error) {
d := &Decoder{r: r}
func NewDecoder(r io.ReadSeeker) common.Decoder {
return &Decoder{raw: r}
}
func (d *Decoder) Validate() error {
// search & derive key
err := d.searchKey()
if err != nil {
return nil, err
return err
}
// check cipher type and init decode cipher
if len(d.decodedKey) > 300 {
d.cipher, err = newRC4Cipher(d.decodedKey)
if err != nil {
return nil, err
return err
}
} else if len(d.decodedKey) != 0 {
d.cipher, err = newMapCipher(d.decodedKey)
if err != nil {
return nil, err
return err
}
} else {
d.cipher = newStaticCipher()
}
_, err = d.r.Seek(0, io.SeekStart)
if err != nil {
return nil, err
}
return d, nil
}
func (d *Decoder) Validate() error {
buf := make([]byte, 16)
if _, err := io.ReadFull(d.r, buf); err != nil {
return err
}
_, err := d.r.Seek(0, io.SeekStart)
if err != nil {
// test with first 16 bytes
if err := d.validateDecode(); err != nil {
return err
}
d.cipher.Decrypt(buf, 0)
fileExt, ok := common.SniffAll(buf)
if !ok {
return errors.New("detect file type failed")
// reset position, limit to audio, prepare for Read
if _, err := d.raw.Seek(0, io.SeekStart); err != nil {
return err
}
d.fileExt = fileExt
d.audio = io.LimitReader(d.raw, int64(d.audioLen))
return nil
}
func (d *Decoder) GetFileExt() string {
return d.fileExt
func (d *Decoder) validateDecode() error {
_, err := d.raw.Seek(0, io.SeekStart)
if err != nil {
return fmt.Errorf("qmc seek to start: %w", err)
}
buf := make([]byte, 16)
if _, err := io.ReadFull(d.raw, buf); err != nil {
return fmt.Errorf("qmc read header: %w", err)
}
d.cipher.Decrypt(buf, 0)
_, ok := common.SniffAll(buf)
if !ok {
return errors.New("qmc: detect file type failed")
}
return nil
}
func (d *Decoder) searchKey() error {
fileSizeM4, err := d.r.Seek(-4, io.SeekEnd)
fileSizeM4, err := d.raw.Seek(-4, io.SeekEnd)
if err != nil {
return err
}
buf, err := io.ReadAll(io.LimitReader(d.r, 4))
buf, err := io.ReadAll(io.LimitReader(d.raw, 4))
if err != nil {
return err
}
@@ -118,13 +120,13 @@ func (d *Decoder) searchKey() error {
}
func (d *Decoder) readRawKey(rawKeyLen int64) error {
audioLen, err := d.r.Seek(-(4 + rawKeyLen), io.SeekEnd)
audioLen, err := d.raw.Seek(-(4 + rawKeyLen), io.SeekEnd)
if err != nil {
return err
}
d.audioLen = int(audioLen)
rawKeyData, err := io.ReadAll(io.LimitReader(d.r, rawKeyLen))
rawKeyData, err := io.ReadAll(io.LimitReader(d.raw, rawKeyLen))
if err != nil {
return err
}
@@ -142,22 +144,22 @@ func (d *Decoder) readRawKey(rawKeyLen int64) error {
func (d *Decoder) readRawMetaQTag() error {
// get raw meta data len
if _, err := d.r.Seek(-8, io.SeekEnd); err != nil {
if _, err := d.raw.Seek(-8, io.SeekEnd); err != nil {
return err
}
buf, err := io.ReadAll(io.LimitReader(d.r, 4))
buf, err := io.ReadAll(io.LimitReader(d.raw, 4))
if err != nil {
return err
}
rawMetaLen := int64(binary.BigEndian.Uint32(buf))
// read raw meta data
audioLen, err := d.r.Seek(-(8 + rawMetaLen), io.SeekEnd)
audioLen, err := d.raw.Seek(-(8 + rawMetaLen), io.SeekEnd)
if err != nil {
return err
}
d.audioLen = int(audioLen)
rawMetaData, err := io.ReadAll(io.LimitReader(d.r, rawMetaLen))
rawMetaData, err := io.ReadAll(io.LimitReader(d.raw, rawMetaLen))
if err != nil {
return err
}
@@ -206,53 +208,6 @@ func init() {
"mflac", "mflac0", //QQ Music New Flac
}
for _, ext := range supportedExts {
common.RegisterDecoder(ext, false, newCompactDecoder)
common.RegisterDecoder(ext, false, NewDecoder)
}
}
type compactDecoder struct {
decoder *Decoder
createErr error
buf *bytes.Buffer
}
func newCompactDecoder(p []byte) common.Decoder {
r := bytes.NewReader(p)
d, err := NewDecoder(r)
c := compactDecoder{
decoder: d,
createErr: err,
}
return &c
}
func (c *compactDecoder) Validate() error {
if c.createErr != nil {
return c.createErr
}
return c.decoder.Validate()
}
func (c *compactDecoder) Decode() error {
if c.createErr != nil {
return c.createErr
}
c.buf = bytes.NewBuffer(nil)
_, err := io.Copy(c.buf, c.decoder)
return err
}
func (c *compactDecoder) GetAudioData() []byte {
return c.buf.Bytes()
}
func (c *compactDecoder) GetAudioExt() string {
if c.createErr != nil {
return ""
}
return c.decoder.GetFileExt()
}
func (c *compactDecoder) GetMeta() common.Meta {
return nil
}