diff --git a/.gitignore b/.gitignore index 14bde69..71976a5 100644 --- a/.gitignore +++ b/.gitignore @@ -6,4 +6,5 @@ examples/*.js dist/ *.pkl coverage/ -__pycache__ \ No newline at end of file +__pycache__ +.idea diff --git a/examples/index.ts b/examples/index.ts index 9544744..a6853d4 100644 --- a/examples/index.ts +++ b/examples/index.ts @@ -1,6 +1,6 @@ import fs from 'node:fs/promises'; import path from 'node:path'; -import { Parser, NameRegistry } from '../'; +import { Parser, NameRegistry, Pickler } from '../src'; class Document extends Map {} @@ -21,8 +21,17 @@ async function unpickle(fname: string) { return parser.parse(buffer); } +async function pickle(obj: unknown, fname: string) { + const pickler = new Pickler({ + protocol: 2, + }); + const buffer = pickler.dump(obj); + await fs.writeFile(fname, buffer); +} + const obj = await unpickle('wiki.pkl'); console.log(obj); +await pickle(obj, 'wiki-processed.pkl'); // const codePoints = Array.from(obj) // .map((v) => v.codePointAt(0).toString(16)) // .map((hex) => '\\u' + hex.padStart(4, 0) + '') diff --git a/src/index.ts b/src/index.ts index 0ea0cd4..35af84b 100644 --- a/src/index.ts +++ b/src/index.ts @@ -1,3 +1,5 @@ export { Parser } from './parser'; export { BufferReader } from './reader'; export { NameRegistry } from './nameRegistry'; +export { Pickler } from './pickler'; +export { BufferWriter } from './writer'; diff --git a/src/parser.ts b/src/parser.ts index f2747af..1df1518 100644 --- a/src/parser.ts +++ b/src/parser.ts @@ -297,7 +297,7 @@ export class Parser { stack = metastack.pop(); const dict = this._dictionaryProvider.create(); for (let i = 0; i < items.length; i += 2) { - dict[items[i]] = items[i + 1]; + this._dictionaryProvider.setMethod(dict, items[i], items[i + 1]); } stack.push(dict); break; diff --git a/src/pickler.ts b/src/pickler.ts new file mode 100644 index 0000000..9741b22 --- /dev/null +++ b/src/pickler.ts @@ -0,0 +1,264 @@ +import { OP } from './opcode'; +import { IWriter } from './writer'; +import { BufferWriter } from './writer'; + +const DefaultOptions: PicklerOptions = { + protocol: 5, +}; + +// Constants for number ranges +const ONE_BYTE_LIMIT = 2 ** (1 * 8) - 1; +const TWO_BYTE_LIMIT = 2 ** (2 * 8) - 1; +const FOUR_BYTE_LIMIT = 2 ** (4 * 8) - 1; +const LONG1_LIMIT_BYTES = 255; + +/** + * Serializes a JavaScript object into a Python pickle format byte stream. + */ +export class Pickler { + readonly #options: PicklerOptions; + readonly #writer: IWriter; + readonly #memo: Map; + + public constructor(options?: Partial) { + this.#options = { ...DefaultOptions, ...options }; + this.#writer = new BufferWriter(); + this.#memo = new Map(); + } + + public dump(obj: unknown): Uint8Array { + this.writeProto(); + this.write(obj); + this.#writer.byte(OP.STOP); + return this.#writer.getBuffer(); + } + + private writeProto(): void { + const protocol = this.#options.protocol; + if (protocol < 0 || protocol > 5) { + throw new Error(`Invalid protocol version: ${protocol}`); + } + + if (protocol >= 2) { + this.#writer.byte(OP.PROTO).byte(protocol); + } + } + + private write(obj: unknown): void { + // Check memo first for any non-primitive object + if ((typeof obj === 'object' && obj !== null) || typeof obj === 'function') { + const memoId = this.#memo.get(obj); + if (memoId !== undefined) { + this.writeGet(memoId); + return; + } + } + + if (obj === null) { + this.#writer.byte(OP.NONE); + return; + } + + if (Array.isArray(obj)) { + this.writeList(obj); + return; + } + + if (obj instanceof Map) { + this.writeDict(obj); + return; + } + + switch (typeof obj) { + case 'boolean': + if (this.#options.protocol >= 2) { + this.#writer.byte(obj ? OP.NEWTRUE : OP.NEWFALSE); + } else { + this.#writer.byte(OP.INT).line(obj ? '01' : '00'); + } + return; + + case 'number': + if (Number.isInteger(obj)) { + if (obj >= 0) { + if (obj <= ONE_BYTE_LIMIT) { + // OP.BININT1: 1-byte unsigned integer + this.#writer.byte(OP.BININT1).byte(obj); + return; + } + if (obj <= TWO_BYTE_LIMIT) { + // OP.BININT2: 2-byte unsigned integer + this.#writer.byte(OP.BININT2).uint16(obj); + return; + } + } + this.#writer.byte(OP.BININT).int32(obj); + } else { + this.#writer.byte(OP.BINFLOAT).float64(obj); + } + return; + + case 'bigint': + this.writeBigInt(obj); + return; + + case 'string': { + const data = this.#writer.encodeUtf8(obj); // Let's add a helper to writer + const len = data.length; + + if (this.#options.protocol >= 4) { + if (len <= ONE_BYTE_LIMIT) { + this.#writer.byte(OP.SHORT_BINUNICODE).byte(len).bytes(data); + } else if (len <= FOUR_BYTE_LIMIT) { + this.#writer.byte(OP.BINUNICODE).uint32(len).bytes(data); + } else { + this.#writer.byte(OP.BINUNICODE8).uint64(BigInt(len)).bytes(data); + } + } else { + // Older protocols (0-3) only have BINUNICODE + if (len <= FOUR_BYTE_LIMIT) { + this.#writer.byte(OP.BINUNICODE).uint32(len).bytes(data); + } else { + throw new Error(`String too long for protocol ${this.#options.protocol}`); + } + } + return; + } + + case 'object': + if (Object.getPrototypeOf(obj) === Object.prototype) { + this.writeDict(obj); + return; + } + break; // Fall through to throw an error for other object types + } + + throw new Error(`Unsupported object type: ${typeof obj}`); + } + + private bigintToLittleEndianBytes(value: bigint): Uint8Array { + if (value < 0) { + throw new Error('Pickling negative BigInts is not supported yet.'); + } + if (value === 0n) { + return new Uint8Array(0); + } + + let hex = value.toString(16); + if (hex.length % 2) { + hex = '0' + hex; + } + + const len = hex.length / 2; + const u8 = new Uint8Array(len); + + for (let i = 0, j = 0; i < len; i++) { + j = (len - 1 - i) * 2; // Read from the end of the hex string (little-endian) + u8[i] = parseInt(hex.substring(j, j + 2), 16); + } + + return u8; + } + + private writeBigInt(value: bigint): void { + if (value >= 0 && value < 1n << 31n) { + const num = Number(value); + if (num <= ONE_BYTE_LIMIT) { + this.#writer.byte(OP.BININT1).byte(num); + return; + } + if (num <= TWO_BYTE_LIMIT) { + this.#writer.byte(OP.BININT2).uint16(num); + return; + } + this.#writer.byte(OP.BININT).int32(num); + return; + } + + // For larger numbers, use LONG1 or LONG4 + const bytes = this.bigintToLittleEndianBytes(value); + + if (bytes.length < LONG1_LIMIT_BYTES) { + this.#writer.byte(OP.LONG1); + this.#writer.byte(bytes.length); + this.#writer.bytes(bytes); + } else { + this.#writer.byte(OP.LONG4); + this.#writer.uint32(bytes.length); + this.#writer.bytes(bytes); + } + } + + private writeList(list: unknown[]): void { + if (this.#options.protocol >= 1) { + this.#writer.byte(OP.EMPTY_LIST); + this.memoize(list); // ВСЕГДА мемоизируем объект списка + if (list.length > 0) { + this.#writer.byte(OP.MARK); + for (const item of list) { + this.write(item); + } + this.#writer.byte(OP.APPENDS); + } + } else { + this.memoize(list); + this.#writer.byte(OP.MARK); + for (const item of list) { + this.write(item); + } + this.#writer.byte(OP.LIST); + } + } + + private memoize(obj: unknown): void { + const id = this.#memo.size; + this.#memo.set(obj, id); + + // Choose the appropriate PUT opcode based on protocol and id size + if (this.#options.protocol >= 4) { + this.#writer.byte(OP.MEMOIZE); + } else if (id <= ONE_BYTE_LIMIT) { + this.#writer.byte(OP.BINPUT).byte(id); + } else { + this.#writer.byte(OP.LONG_BINPUT).uint32(id); + } + } + + private writeGet(id: number): void { + // Choose the appropriate GET opcode based on id size + if (id <= ONE_BYTE_LIMIT) { + this.#writer.byte(OP.BINGET).byte(id); + } else { + this.#writer.byte(OP.LONG_BINGET).uint32(id); + } + } + + private writeDict(dict: object | Map): void { + const iterableEntries = [...(dict instanceof Map ? dict.entries() : Object.entries(dict))]; + + if (this.#options.protocol >= 1) { + this.#writer.byte(OP.EMPTY_DICT); + this.memoize(dict); + if (iterableEntries.length > 0) { + this.#writer.byte(OP.MARK); + for (const [key, value] of iterableEntries) { + this.write(key); + this.write(value); + } + this.#writer.byte(OP.SETITEMS); + } + } else { + this.memoize(dict); + this.#writer.byte(OP.MARK); + for (const [key, value] of iterableEntries) { + this.write(key); + this.write(value); + } + this.#writer.byte(OP.DICT); + } + } +} + +export interface PicklerOptions { + protocol: number; +} diff --git a/src/reader.ts b/src/reader.ts index 9b24995..d30cda3 100644 --- a/src/reader.ts +++ b/src/reader.ts @@ -1,4 +1,6 @@ -type Encoding = 'ascii' | 'utf-8'; +import { BITS_PER_BYTE, Sizes } from './sizes'; + +export type Encoding = 'ascii' | 'utf-8'; export interface IReader { byte(): number; @@ -38,7 +40,7 @@ export class BufferReader implements IReader { byte() { const position = this._position; - this.skip(1); + this.skip(Sizes.Byte); return this._dataView.getUint8(position); } @@ -50,30 +52,30 @@ export class BufferReader implements IReader { uint16() { const position = this.position; - this.skip(2); + this.skip(Sizes.UInt16); return this._dataView.getUint16(position, true); } int32() { const position = this.position; - this.skip(4); + this.skip(Sizes.Int32); return this._dataView.getInt32(position, true); } uint32() { const position = this.position; - this.skip(4); + this.skip(Sizes.UInt32); return this._dataView.getUint32(position, true); } uint64() { const position = this.position; - this.skip(8); + this.skip(Sizes.UInt64); // split 64-bit number into two 32-bit parts const left = this._dataView.getUint32(position, true); - const right = this._dataView.getUint32(position + 4, true); + const right = this._dataView.getUint32(position + Sizes.UInt32, true); // combine the two 32-bit values - const number = left + 2 ** 32 * right; + const number = left + 2 ** (Sizes.UInt32 * BITS_PER_BYTE) * right; if (!Number.isSafeInteger(number)) { console.warn(number, 'exceeds MAX_SAFE_INTEGER. Precision may be lost'); } @@ -89,7 +91,7 @@ export class BufferReader implements IReader { float64() { const position = this.position; - this.skip(8); + this.skip(Sizes.Float64); return this._dataView.getFloat64(position, false); } @@ -116,7 +118,7 @@ export class BufferReader implements IReader { } const size = index - this._position; const text = this.string(size, 'ascii'); - this.skip(1); + this.skip(Sizes.Byte); return text; } @@ -126,11 +128,11 @@ export class BufferReader implements IReader { } export function readUint64(data: Uint8Array | Int8Array | Uint8ClampedArray) { - if (data.length > 8) { + if (data.length > Sizes.UInt64) { throw new Error('Value too large to unpickling'); } // Padding to 8 bytes - const buffer = new ArrayBuffer(8); + const buffer = new ArrayBuffer(Sizes.UInt64); const uint8 = new Uint8Array(buffer); uint8.set(data); const subReader = new BufferReader(uint8); @@ -142,7 +144,7 @@ export function readUint64WithBigInt(data: Uint8Array | Int8Array | Uint8Clamped let fixedLength = 0; let partCount = 0; while (fixedLength < data.length) { - fixedLength += 4; + fixedLength += Sizes.UInt32; partCount += 1; } const buffer = new ArrayBuffer(fixedLength); @@ -151,8 +153,8 @@ export function readUint64WithBigInt(data: Uint8Array | Int8Array | Uint8Clamped const view = new DataView(buffer, 0, fixedLength); let number = BigInt(0); for (let partIndex = 0; partIndex < partCount; partIndex++) { - const part = BigInt(view.getUint32(partIndex * 4, true)); - number |= part << BigInt(partIndex * 32); + const part = BigInt(view.getUint32(partIndex * Sizes.UInt32, true)); + number |= part << BigInt(partIndex * Sizes.UInt32 * BITS_PER_BYTE); } return number; } diff --git a/src/sizes.ts b/src/sizes.ts new file mode 100644 index 0000000..e313faf --- /dev/null +++ b/src/sizes.ts @@ -0,0 +1,10 @@ +export const enum Sizes { + Byte = 1, + UInt16 = 2, + Int32 = 4, + UInt32 = 4, + UInt64 = 8, + Float64 = 8, +} + +export const BITS_PER_BYTE = 8; diff --git a/src/writer.ts b/src/writer.ts new file mode 100644 index 0000000..e51c1f5 --- /dev/null +++ b/src/writer.ts @@ -0,0 +1,114 @@ +import { Encoding } from './reader'; +import { Sizes } from './sizes'; + +export interface IWriter { + byte(value: number): IWriter; + bytes(data: Uint8Array): IWriter; + uint16(value: number): IWriter; + int32(value: number): IWriter; + uint32(value: number): IWriter; + uint64(value: number | bigint): IWriter; + float64(value: number): IWriter; + line(text: string): IWriter; + string(text: string, encoding: Encoding): IWriter; + + /** Returns the complete buffer containing all written data */ + getBuffer(): Uint8Array; + encodeUtf8(text: string): Uint8Array; +} + +/** + * Class for writing data to a dynamically growing buffer. + * The internal buffer automatically expands as more data is written. + */ +export class BufferWriter implements IWriter { + #buffer: Uint8Array; + #dataView: DataView; + #position: number; + + readonly #utf8Encoder = new TextEncoder(); + + public constructor(initialCapacity = 1024) { + this.#buffer = new Uint8Array(initialCapacity); + this.#dataView = new DataView(this.#buffer.buffer); + this.#position = 0; + } + + public byte(value: number): this { + return this.write(Sizes.Byte, (pos) => this.#dataView.setUint8(pos, value)); + } + + public bytes(data: Uint8Array): this { + return this.write(data.length, (pos) => this.#buffer.set(data, pos)); + } + + public uint16(value: number): this { + return this.write(Sizes.UInt16, (pos) => this.#dataView.setUint16(pos, value, true)); + } + + public int32(value: number): this { + return this.write(Sizes.Int32, (pos) => this.#dataView.setInt32(pos, value, true)); + } + + public uint32(value: number): this { + return this.write(Sizes.UInt32, (pos) => this.#dataView.setUint32(pos, value, true)); + } + + public uint64(value: number | bigint): this { + return this.write(Sizes.UInt64, (pos) => this.#dataView.setBigUint64(pos, BigInt(value), true)); + } + + public float64(value: number): this { + return this.write(Sizes.Float64, (pos) => this.#dataView.setFloat64(pos, value, false)); + } + + public line(text: string): this { + this.string(text, 'ascii'); + this.byte(0x0a); // LF + return this; + } + + public string(text: string, encoding: Encoding): this { + if (encoding === 'utf-8') { + const data = this.#utf8Encoder.encode(text); + return this.bytes(data); + } else { + // 'ascii' + return this.write(text.length, (pos) => { + for (let i = 0; i < text.length; i++) { + this.#dataView.setUint8(pos + i, text.charCodeAt(i) & 0xff); + } + }); + } + } + + public getBuffer(): Uint8Array { + return this.#buffer.subarray(0, this.#position); + } + + public encodeUtf8(text: string): Uint8Array { + return this.#utf8Encoder.encode(text); + } + + private ensureCapacity(requiredBytes: number): void { + if (this.#position + requiredBytes < this.#buffer.byteLength) return; + const newCapacity = Math.max(this.#buffer.byteLength * 2, this.#position + requiredBytes); + const newBuffer = new Uint8Array(newCapacity); + newBuffer.set(this.#buffer); // Copy old data + this.#buffer = newBuffer; + this.#dataView = new DataView(this.#buffer.buffer); + } + + /** + * Ensures capacity, executes the write function and advances the position. + * @param size - The number of bytes to write. + * @param writeFn - The function that performs the actual write operation at a given position. + * @returns The writer instance for chaining. + */ + private write(size: number, writeFn: (position: number) => void): this { + this.ensureCapacity(size); + writeFn(this.#position); + this.#position += size; + return this; + } +} diff --git a/test/integration/_caller.ts b/test/integration/_caller.ts index 57d8442..38eb5ce 100644 --- a/test/integration/_caller.ts +++ b/test/integration/_caller.ts @@ -18,3 +18,31 @@ export async function caller(file: string, func: string, protocol: PROTOCOL = '5 }); }); } + +export async function pythonUnpickle(data: Uint8Array): Promise { + const python = spawn(DefaultPythonPath, [path.join(__dirname, '_unpickler.py')]); + + return new Promise((resolve, reject) => { + let stdout = ''; + let stderr = ''; + + python.stdout.on('data', (data) => { + stdout += data.toString(); + }); + + python.stderr.on('data', (data) => { + stderr += data.toString(); + }); + + python.on('close', (code) => { + if (code === 0) { + resolve(stdout.trim()); + } else { + reject(new Error(`Python unpickler exited with code ${code}: ${stderr}`)); + } + }); + + python.stdin.write(data); + python.stdin.end(); + }); +} diff --git a/test/integration/_unpickler.py b/test/integration/_unpickler.py new file mode 100644 index 0000000..dcace38 --- /dev/null +++ b/test/integration/_unpickler.py @@ -0,0 +1,12 @@ +import sys +import pickle + +if __name__ == "__main__": + pickled_data = sys.stdin.buffer.read() + + try: + obj = pickle.loads(pickled_data) + print(repr(obj)) + except Exception as e: + print(f"Error unpickling data: {e}", file=sys.stderr) + sys.exit(1) diff --git a/test/integration/intergration.test.ts b/test/integration/intergration.test.ts index 2f99f70..011df5c 100644 --- a/test/integration/intergration.test.ts +++ b/test/integration/intergration.test.ts @@ -1,254 +1,299 @@ import { Parser } from '../../src/parser'; import { basic } from './basic'; -import { PROTOCOL, PROTOCOLS, caller } from './_caller'; +import { PROTOCOL, PROTOCOLS, caller, pythonUnpickle } from './_caller'; import { NameRegistry } from '../../src/nameRegistry'; +import { Pickler } from '../../src/pickler'; -describe('basic with version', () => { - it.each( - Object.keys(basic).reduce((a: Array<[string, PROTOCOL]>, c) => { - PROTOCOLS.forEach((p) => { - a.push([c, p]); - }); - return a; - }, []), - )('correctly unpickled (%s) with protocol %s', async (func, protocol) => { - const data = await caller('basic', func, protocol); - const expected = basic[func](); - const obj = new Parser().parse(data); - expect(obj).toStrictEqual(expected); +describe('Parser', () => { + describe('basic with version', () => { + it.each( + Object.keys(basic).reduce((a: Array<[string, PROTOCOL]>, c) => { + PROTOCOLS.forEach((p) => { + a.push([c, p]); + }); + return a; + }, []), + )('correctly unpickled (%s) with protocol %s', async (func, protocol) => { + const data = await caller('basic', func, protocol); + const expected = basic[func](); + const obj = new Parser().parse(data); + expect(obj).toStrictEqual(expected); + }); }); -}); -describe('klass', () => { - it('correctly unpickl class', async () => { - const expected = { - array: [1, true, false, null, 4294967295], - fruits: ['apple', 'banana', 'cherry'], - str: 'test', - }; - const data = await caller('klass', 'klass'); - const obj = new Parser().parse(data); - expect(obj).toMatchObject(expected); - const prototype = Object.getPrototypeOf(obj); - expect(prototype).toHaveProperty('__module__', 'klass'); - expect(prototype).toHaveProperty('__name__', 'MyClass'); - }); + describe('klass', () => { + it('correctly unpickl class', async () => { + const expected = { + array: [1, true, false, null, 4294967295], + fruits: ['apple', 'banana', 'cherry'], + str: 'test', + }; + const data = await caller('klass', 'klass'); + const obj = new Parser().parse(data); + expect(obj).toMatchObject(expected); + const prototype = Object.getPrototypeOf(obj); + expect(prototype).toHaveProperty('__module__', 'klass'); + expect(prototype).toHaveProperty('__name__', 'MyClass'); + }); - it('correctly unpickl reduce', async () => { - const expected = ['379', 'acd']; - const data = await caller('klass', 'reduce'); - const obj = new Parser().parse<{ - args: typeof expected; - }>(data); - expect(obj.args).toStrictEqual(expected); - const prototype = Object.getPrototypeOf(obj); - expect(prototype).toHaveProperty('__module__', 'klass'); - expect(prototype).toHaveProperty('__name__', 'Reduce'); - }); + it('correctly unpickl reduce', async () => { + const expected = ['379', 'acd']; + const data = await caller('klass', 'reduce'); + const obj = new Parser().parse<{ + args: typeof expected; + }>(data); + expect(obj.args).toStrictEqual(expected); + const prototype = Object.getPrototypeOf(obj); + expect(prototype).toHaveProperty('__module__', 'klass'); + expect(prototype).toHaveProperty('__name__', 'Reduce'); + }); - it('correctly unpickl with customized reduce', async () => { - const expected = ['379', 'acd'].join(','); - const data = await caller('klass', 'reduce'); - const obj = new Parser({ - nameResolver: { - resolve: - () => - (...args) => - args.join(','), - }, - }).parse(data); - expect(obj).toStrictEqual(expected); - }); + it('correctly unpickl with customized reduce', async () => { + const expected = ['379', 'acd'].join(','); + const data = await caller('klass', 'reduce'); + const obj = new Parser({ + nameResolver: { + resolve: + () => + (...args) => + args.join(','), + }, + }).parse(data); + expect(obj).toStrictEqual(expected); + }); - it('with NameRegistry', async () => { - class MyClass { - // eslint-disable-next-line @typescript-eslint/no-explicit-any - array: any[] = []; - fruits: string[] = []; - str: string | undefined; - } - - const expected = new MyClass(); - expected.array = [1, true, false, null, 4294967295]; - expected.fruits = ['apple', 'banana', 'cherry']; - expected.str = 'test'; - - const registry = new NameRegistry().register('klass', 'MyClass', MyClass); - - const data = await caller('klass', 'klass'); - const obj = new Parser({ - nameResolver: registry, - }).parse(data); - expect(obj).toStrictEqual(expected); - }); -}); + it('with NameRegistry', async () => { + class MyClass { + // eslint-disable-next-line @typescript-eslint/no-explicit-any + array: any[] = []; + fruits: string[] = []; + str: string | undefined; + } -describe('long', () => { - it('correctly unpickl long4', async () => { - const expected = - '398975380292520334652879605459872453583337697310312075368132832107577484344336733631267656030491697975730319706482065334076861898105601071965378874928385172947532987811819170262577671821213107453864961626619206677733226746811432113436292476143879109827341785450751168773699095642803230314850775984760865060210698176217969616820780411226838820532294252280967457872536334204782734733250992661824578272424009848363696719993178605177831220265480726240392211861785871121272419902398980354868722930075005775349612053991637884090384318432325292289645358976894389564741519731777181442461757394830856205901222601730755928883177016176597856688472084971520'; - const data = await caller('long', 'long4'); - const obj = new Parser().parse(data); - expect(obj.toString()).toStrictEqual(expected); - }); -}); + const expected = new MyClass(); + expected.array = [1, true, false, null, 4294967295]; + expected.fruits = ['apple', 'banana', 'cherry']; + expected.str = 'test'; -describe('protocol5', () => { - it('correctly unpickl bytearray8', async () => { - const expected = Buffer.from([1, 2, 3]); - const data = await caller('protocol5', 'bytearray8'); - const obj = new Parser().parse(data); - expect(obj).toStrictEqual(expected); - }); + const registry = new NameRegistry().register('klass', 'MyClass', MyClass); - it('correctly unpickl next_buffer', async () => { - const expected = 123; - const data = await caller('protocol5', 'next_buffer'); - const obj = new Parser({ - buffers: (function* () { - yield expected; - })(), - }).parse(data); - expect(obj).toStrictEqual(expected); + const data = await caller('klass', 'klass'); + const obj = new Parser({ + nameResolver: registry, + }).parse(data); + expect(obj).toStrictEqual(expected); + }); }); - it('correctly unpickl multi_next_buffer', async () => { - const expected = [123, 'str']; - const data = await caller('protocol5', 'multi_next_buffer'); - const obj = new Parser({ - buffers: expected.values(), - }).parse(data); - expect(obj).toStrictEqual(expected); + describe('long', () => { + it('correctly unpickl long4', async () => { + const expected = + '398975380292520334652879605459872453583337697310312075368132832107577484344336733631267656030491697975730319706482065334076861898105601071965378874928385172947532987811819170262577671821213107453864961626619206677733226746811432113436292476143879109827341785450751168773699095642803230314850775984760865060210698176217969616820780411226838820532294252280967457872536334204782734733250992661824578272424009848363696719993178605177831220265480726240392211861785871121272419902398980354868722930075005775349612053991637884090384318432325292289645358976894389564741519731777181442461757394830856205901222601730755928883177016176597856688472084971520'; + const data = await caller('long', 'long4'); + const obj = new Parser().parse(data); + expect(obj.toString()).toStrictEqual(expected); + }); }); - it('correctly unpickl readonly_buffer', async () => { - const expected = 123; - const data = await caller('protocol5', 'readonly_buffer'); - const obj = new Parser({ - buffers: (function* () { - yield expected; - })(), - }).parse(data); - expect(obj).toStrictEqual(expected); - }); + describe('protocol5', () => { + it('correctly unpickl bytearray8', async () => { + const expected = Buffer.from([1, 2, 3]); + const data = await caller('protocol5', 'bytearray8'); + const obj = new Parser().parse(data); + expect(obj).toStrictEqual(expected); + }); - it('correctly unpickl next_buffer_and_readonly_buffer', async () => { - const expected = [123, [1, '22', null]]; - const data = await caller('protocol5', 'next_buffer_and_readonly_buffer'); - const obj = new Parser({ - buffers: expected.values(), - }).parse(data); - expect(obj).toStrictEqual(expected); - }); + it('correctly unpickl next_buffer', async () => { + const expected = 123; + const data = await caller('protocol5', 'next_buffer'); + const obj = new Parser({ + buffers: (function* () { + yield expected; + })(), + }).parse(data); + expect(obj).toStrictEqual(expected); + }); - it('correctly unpickl next_buffer_with_reduce_ex', async () => { - class mybytearray { - public args: number[]; - constructor(args: number[]) { - this.args = args; - } - static __reduce_ex__(args: number[]) { - return new mybytearray(args); + it('correctly unpickl multi_next_buffer', async () => { + const expected = [123, 'str']; + const data = await caller('protocol5', 'multi_next_buffer'); + const obj = new Parser({ + buffers: expected.values(), + }).parse(data); + expect(obj).toStrictEqual(expected); + }); + + it('correctly unpickl readonly_buffer', async () => { + const expected = 123; + const data = await caller('protocol5', 'readonly_buffer'); + const obj = new Parser({ + buffers: (function* () { + yield expected; + })(), + }).parse(data); + expect(obj).toStrictEqual(expected); + }); + + it('correctly unpickl next_buffer_and_readonly_buffer', async () => { + const expected = [123, [1, '22', null]]; + const data = await caller('protocol5', 'next_buffer_and_readonly_buffer'); + const obj = new Parser({ + buffers: expected.values(), + }).parse(data); + expect(obj).toStrictEqual(expected); + }); + + it('correctly unpickl next_buffer_with_reduce_ex', async () => { + class mybytearray { + public args: number[]; + constructor(args: number[]) { + this.args = args; + } + static __reduce_ex__(args: number[]) { + return new mybytearray(args); + } } - } - const externalData = [1, 2, 3, 4]; - const expected = new mybytearray(externalData); - const registry = new NameRegistry(); - registry.register('protocol5', 'mybytearray', mybytearray.__reduce_ex__); - const data = await caller('protocol5', 'next_buffer_with_reduce_ex'); - const obj = new Parser({ - nameResolver: registry, - buffers: (function* () { - yield externalData; - })(), - }).parse(data); - expect(obj).toStrictEqual(expected); + const externalData = [1, 2, 3, 4]; + const expected = new mybytearray(externalData); + const registry = new NameRegistry(); + registry.register('protocol5', 'mybytearray', mybytearray.__reduce_ex__); + const data = await caller('protocol5', 'next_buffer_with_reduce_ex'); + const obj = new Parser({ + nameResolver: registry, + buffers: (function* () { + yield externalData; + })(), + }).parse(data); + expect(obj).toStrictEqual(expected); + }); }); -}); -describe('dict', () => { - it.each(PROTOCOLS)('correctly unpickl emptydict', async (protocol) => { - const data = await caller('dict', 'emptydict', protocol); - const obj = new Parser().parse(data); - expect(obj).toStrictEqual({}); - }); + describe('dict', () => { + it.each(PROTOCOLS)('correctly unpickl emptydict', async (protocol) => { + const data = await caller('dict', 'emptydict', protocol); + const obj = new Parser().parse(data); + expect(obj).toStrictEqual({}); + }); - it.each(PROTOCOLS)('correctly unpickl emptydict with Map (protocol: %d)', async (protocol) => { - const data = await caller('dict', 'emptydict', protocol); - const obj = new Parser({ - unpicklingTypeOfDictionary: 'Map', - }).parse(data); - expect(obj).toStrictEqual(new Map()); - }); + it.each(PROTOCOLS)('correctly unpickl emptydict with Map (protocol: %d)', async (protocol) => { + const data = await caller('dict', 'emptydict', protocol); + const obj = new Parser({ + unpicklingTypeOfDictionary: 'Map', + }).parse(data); + expect(obj).toStrictEqual(new Map()); + }); - it.each(PROTOCOLS)('correctly unpickl dict w/ data', async (protocol) => { - const data = await caller('dict', 'dict1', protocol); - const obj = new Parser().parse(data); - expect(obj).toStrictEqual({ - key: 'foo', + it.each(PROTOCOLS)('correctly unpickl dict w/ data', async (protocol) => { + const data = await caller('dict', 'dict1', protocol); + const obj = new Parser().parse(data); + expect(obj).toStrictEqual({ + key: 'foo', + }); + }); + + it.each(PROTOCOLS)('correctly unpickl dict w/ multidata', async (protocol) => { + const data = await caller('dict', 'dict2', protocol); + const obj = new Parser().parse(data); + expect(obj).toStrictEqual({ + key: 'foo', + key2: 123, + key3: {}, + }); }); }); - it.each(PROTOCOLS)('correctly unpickl dict w/ multidata', async (protocol) => { - const data = await caller('dict', 'dict2', protocol); - const obj = new Parser().parse(data); - expect(obj).toStrictEqual({ - key: 'foo', - key2: 123, - key3: {}, + describe('set', () => { + it.each(['0', '1', '2'] as const)('correctly unpickl emptyset with p/ 0,1,2', async (protocol) => { + const registry = new NameRegistry().register('__builtin__', 'set', Array.from); + const data = await caller('set', 'emptyset', protocol); + const obj = new Parser({ + nameResolver: registry, + }).parse(data); + expect(obj).toStrictEqual([]); + }); + + it('correctly unpickl emptyset with p/ 3', async () => { + const registry = new NameRegistry().register('builtins', 'set', Array.from); + const data = await caller('set', 'emptyset', '3'); + const obj = new Parser({ + nameResolver: registry, + unpicklingTypeOfSet: 'Set', + }).parse(data); + expect(obj).toStrictEqual([]); + }); + + it('correctly unpickl emptyset with p/ 5', async () => { + const data = await caller('set', 'emptyset', '5'); + const obj = new Parser({ + unpicklingTypeOfSet: 'Set', + }).parse(data); + expect(obj).toStrictEqual(new Set()); + }); + + it('correctly unpickl emptyset with p/ 5 and array', async () => { + const data = await caller('set', 'emptyset', '5'); + const obj = new Parser({ + unpicklingTypeOfSet: 'array', + }).parse(data); + expect(obj).toStrictEqual([]); + }); + + it('correctly unpickl set with data with p/ 5', async () => { + const data = await caller('set', 'set1', '5'); + const obj = new Parser({ + unpicklingTypeOfSet: 'Set', + }).parse(data); + expect(obj).toStrictEqual(new Set(['apple', 'banana', 'cherry'])); + }); + + it('correctly unpickl frozenset', async () => { + const data = await caller('set', 'frozenset1'); + const obj = new Parser({ + unpicklingTypeOfSet: 'Set', + }).parse(data); + expect(obj).toStrictEqual(new Set([1, 2])); }); }); }); -describe('set', () => { - it.each(['0', '1', '2'] as const)('correctly unpickl emptyset with p/ 0,1,2', async (protocol) => { - const registry = new NameRegistry().register('__builtin__', 'set', Array.from); - const data = await caller('set', 'emptyset', protocol); - const obj = new Parser({ - nameResolver: registry, - }).parse(data); - expect(obj).toStrictEqual([]); - }); +describe('Pickler', () => { + describe('Integration with Python Unpickler', () => { + const PROTOCOLS = [1, 2, 3, 4, 5]; - it('correctly unpickl emptyset with p/ 3', async () => { - const registry = new NameRegistry().register('builtins', 'set', Array.from); - const data = await caller('set', 'emptyset', '3'); - const obj = new Parser({ - nameResolver: registry, - unpicklingTypeOfSet: 'Set', - }).parse(data); - expect(obj).toStrictEqual([]); - }); + const integrationTestCases = [ + ['a simple list', [1, 'hello', true], "[1, 'hello', True]"], + ['a simple dict', { a: 1, b: null }, "{'a': 1, 'b': None}"], + ['a nested structure', { data: [1, { value: 3.14 }] }, "{'data': [1, {'value': 3.14}]}"], + ['a structure with a BigInt', { big: 123456789n }, "{'big': 123456789}"], + [ + 'a self-referencing list', + ((): unknown[] => { + const l: unknown[] = [1]; + l.push(l); + return l; + })(), + '[1, [...]]', + ], + ] as const; - it('correctly unpickl emptyset with p/ 5', async () => { - const data = await caller('set', 'emptyset', '5'); - const obj = new Parser({ - unpicklingTypeOfSet: 'Set', - }).parse(data); - expect(obj).toStrictEqual(new Set()); - }); + const testCases = integrationTestCases.flatMap(([name, value, expectedRepr]) => + PROTOCOLS.map((p) => [name, value, p, expectedRepr]), + ); - it('correctly unpickl emptyset with p/ 5 and array', async () => { - const data = await caller('set', 'emptyset', '5'); - const obj = new Parser({ - unpicklingTypeOfSet: 'array', - }).parse(data); - expect(obj).toStrictEqual([]); - }); + it.each(testCases)( + 'correctly generates pickle for %s (protocol %p)', + async (name, value, protocol, expectedRepr) => { + const pickler = new Pickler({ protocol: protocol as number }); + const pickledData = pickler.dump(value); - it('correctly unpickl set with data with p/ 5', async () => { - const data = await caller('set', 'set1', '5'); - const obj = new Parser({ - unpicklingTypeOfSet: 'Set', - }).parse(data); - expect(obj).toStrictEqual(new Set(['apple', 'banana', 'cherry'])); - }); + const pythonRepr = await pythonUnpickle(pickledData); - it('correctly unpickl frozenset', async () => { - const data = await caller('set', 'frozenset1'); - const obj = new Parser({ - unpicklingTypeOfSet: 'Set', - }).parse(data); - expect(obj).toStrictEqual(new Set([1, 2])); + if (name === 'a self-referencing list') { + expect(pythonRepr).toMatch(/^\[1, \]$|^\[1, \[...\]\]$/); + } else { + expect(pythonRepr).toBe(expectedRepr); + } + }, + ); }); }); diff --git a/test/parser.test.ts b/test/parser.test.ts index f6573cc..36f0f40 100644 --- a/test/parser.test.ts +++ b/test/parser.test.ts @@ -136,5 +136,54 @@ describe('Parser', () => { const pkl = new Uint8Array([OP.PROTO, 6, OP.STOP]); expect(() => parser.parse(pkl)).toThrow("Unsupported protocol version '6'."); }); + + it('should correctly parse a dictionary into a Map with non-string keys', () => { + // This simulates a pickle stream for a dictionary: {1: "one", null: "is null"} + // using protocol 2. + // The stream is: PROTO 2, MARK, BININT1 1, BINUNICODE "one", NONE, BINUNICODE "is null", DICT, STOP + const pkl = new Uint8Array([ + OP.PROTO, + 2, // Protocol header + OP.MARK, // Mark for dictionary items + OP.BININT1, + 1, // Key: 1 (integer) + OP.BINUNICODE, + 3, + 0, + 0, + 0, + 0x6f, + 0x6e, + 0x65, // Value: "one" + OP.NONE, // Key: null + OP.BINUNICODE, + 7, + 0, + 0, + 0, + 0x69, + 0x73, + 0x20, + 0x6e, + 0x75, + 0x6c, + 0x6c, // Value: "is null" + OP.DICT, // Build dictionary from items + OP.STOP, // Stop + ]); + + // Configure the parser to produce Maps + const parser = new Parser({ + unpicklingTypeOfDictionary: 'Map', + }); + + const result = parser.parse>(pkl); + + const expected = new Map(); + expected.set(1, 'one'); + expected.set(null, 'is null'); + + expect(result).toStrictEqual(expected); + }); }); }); diff --git a/test/pickler.test.ts b/test/pickler.test.ts new file mode 100644 index 0000000..4b35ccb --- /dev/null +++ b/test/pickler.test.ts @@ -0,0 +1,184 @@ +import { Pickler } from '../src/pickler'; +import { Parser } from '../src/parser'; +import { OP } from '../src/opcode'; + +describe('Pickler', () => { + const PROTOCOLS = [0, 1, 2, 3, 4, 5] as const; + const PROTOCOLS_WITH_CYCLE_SUPPORT = PROTOCOLS.filter((p) => p >= 1); + + describe('#dump() - Primitives', () => { + const primitiveValues = [ + ['null', null], + ['boolean true', true], + ['boolean false', false], + ['a small positive integer (1 byte)', 123], + ['a medium integer (2 bytes)', 1000], + ['a larger integer (4 bytes)', 100000], + ['a negative integer', -500500], + ['zero', 0], + ['a float', 3.14], + ['an empty string', ''], + ['an ASCII string', 'hello world'], + ['a UTF-8 string', '你好, world!'], + ['medium-size string', 'a'.repeat(500)], + ['a big BigInt', 1234567891234567n], + ['a small BigInt', 1n], + ] as const; + + const primitiveTestCases = primitiveValues.flatMap(([testName, value]) => { + const applicableProtocols = typeof value === 'bigint' ? PROTOCOLS_WITH_CYCLE_SUPPORT : PROTOCOLS; + return applicableProtocols.map((protocol) => [testName, value, protocol] as const); + }); + + it.each(primitiveTestCases)('correctly pickles and unpickles %s (protocol %p)', (testName, value, protocol) => { + // FIX: Explicitly cast protocol to number + const pickler = new Pickler({ protocol: protocol as number }); + const parser = new Parser(); + + const pickledData = pickler.dump(value); + const unpickledData = parser.parse(pickledData); + + if (typeof value === 'bigint') { + expect(BigInt(unpickledData as number | bigint)).toBe(value); + } else { + expect(unpickledData).toStrictEqual(value); + } + }); + }); + + describe('#dump() - Containers', () => { + const objectLikeCases = [ + ['an empty array', []], + ['an array of numbers', [1, 2, 3]], + ['an array of mixed primitives', ['a', null, true, 1.5]], + ['a nested array', [1, [2, 3], 4]], + ['an empty object', {}], + ['a simple object', { a: 1, b: 'hello' }], + ['a nested object', { a: 1, b: { c: 2 } }], + ]; + + it.each(objectLikeCases.flatMap(([name, value]) => PROTOCOLS.map((p) => [name, value, p])))( + 'correctly pickles/unpickles %s (protocol %p) to an Object', + (name, value, protocol) => { + // FIX: Explicitly cast protocol to number + const pickler = new Pickler({ protocol: protocol as number }); + const parser = new Parser(); + + const pickledData = pickler.dump(value); + const unpickledData = parser.parse(pickledData); + + expect(unpickledData).toStrictEqual(value); + }, + ); + + const mapLikeCases = [ + ['an empty Map', new Map()], + [ + 'a Map with mixed key types', + new Map([ + ['a', 1], + [2, 'b'], + [null, true], + ]), + ], + ]; + + it.each(mapLikeCases.flatMap(([name, value]) => PROTOCOLS.map((p) => [name, value, p])))( + 'correctly pickles/unpickles %s (protocol %p) to a Map', + (name, value, protocol) => { + // FIX: Explicitly cast protocol to number + const pickler = new Pickler({ protocol: protocol as number }); + const parser = new Parser({ unpicklingTypeOfDictionary: 'Map' }); + + const pickledData = pickler.dump(value); + const unpickledData = parser.parse(pickledData); + + expect(unpickledData).toStrictEqual(value); + }, + ); + + it.each(PROTOCOLS_WITH_CYCLE_SUPPORT.map((p) => [p]))( + 'correctly pickles a self-referencing array (protocol %p)', + (protocol) => { + // FIX: Explicitly cast protocol to number + const pickler = new Pickler({ protocol: protocol as number }); + const parser = new Parser(); + + const arr: unknown[] = [1]; + arr.push(arr); + + const pickledData = pickler.dump(arr); + const unpickledData = parser.parse(pickledData); + + expect(unpickledData[0]).toBe(1); + expect(unpickledData[1]).toBe(unpickledData); + }, + ); + + it.each(PROTOCOLS_WITH_CYCLE_SUPPORT.map((p) => [p]))( + 'correctly pickles a self-referencing object (protocol %p)', + (protocol) => { + // FIX: Explicitly cast protocol to number + const pickler = new Pickler({ protocol: protocol as number }); + const parser = new Parser({ unpicklingTypeOfDictionary: 'Map' }); + + const obj = new Map(); + obj.set('a', 1); + obj.set('self', obj); + + const pickledData = pickler.dump(obj); + const unpickledData = parser.parse>(pickledData); + + expect(unpickledData.get('a')).toBe(1); + expect(unpickledData.get('self')).toBe(unpickledData); + }, + ); + + it.each(PROTOCOLS_WITH_CYCLE_SUPPORT.map((p) => [p]))( + 'correctly pickles a shared empty container (protocol %p)', + (protocol) => { + const pickler = new Pickler({ protocol }); + const parser = new Parser(); + + const sharedEmptyList: unknown[] = []; + const data = { + a: sharedEmptyList, + b: sharedEmptyList, + }; + + const pickledData = pickler.dump(data); + const unpickledData = parser.parse<{ a: unknown[]; b: unknown[] }>(pickledData); + + // 1. Check the structure + expect(unpickledData.a).toEqual([]); + expect(unpickledData.b).toEqual([]); + + // 2. Check that 'a' and 'b' point to the *exact same* object in memory + expect(unpickledData.a).toBe(unpickledData.b); + }, + ); + }); + + describe('Protocol Handling', () => { + // This section is fine + it('writes protocol header for protocol >= 2', () => { + const pickler = new Pickler({ protocol: 4 }); + const data = pickler.dump(null); + expect(data[0]).toBe(OP.PROTO); + expect(data[1]).toBe(4); + expect(data[2]).toBe(OP.NONE); + expect(data[3]).toBe(OP.STOP); + }); + + it('does not write protocol header for protocol < 2', () => { + const pickler = new Pickler({ protocol: 1 }); + const data = pickler.dump(null); + expect(data[0]).toBe(OP.NONE); + expect(data[1]).toBe(OP.STOP); + }); + + it('throws an error for invalid protocol version', () => { + expect(() => new Pickler({ protocol: 6 }).dump(null)).toThrow('Invalid protocol version: 6'); + }); + }); +}); diff --git a/test/writer.test.ts b/test/writer.test.ts new file mode 100644 index 0000000..9449b66 --- /dev/null +++ b/test/writer.test.ts @@ -0,0 +1,75 @@ +import { BufferReader } from '../src/reader'; +import { BufferWriter } from '../src/writer'; + +describe('BufferWriter', () => { + let writer: BufferWriter; + + beforeEach(() => { + writer = new BufferWriter(4); + }); + + const testCases = [ + ['byte', 0xff], + ['bytes', new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8])], + ['uint16', 0x0302], + ['int32', -32767], + ['uint32', 0x05040302], + ['uint64', 281474976710656n], + ['float64', Math.PI], + ['line', '123qwe'], + ] as const; + + it.each(testCases)( + 'writes %s', + ( + method: M, + value: Parameters[0], + ...args: unknown[] + ) => { + // @ts-expect-error args has correct type due to testCases structure + writer[method](value, ...args); + + const reader = new BufferReader(writer.getBuffer()); + // @ts-expect-error same + const result = reader[method](value.length); // length for bytes and string + + if (typeof value === 'number' && !Number.isInteger(value)) { + expect(result).toBeCloseTo(value); + } else if (method === 'uint64') { + expect(BigInt(result as number)).toEqual(value); + } else { + expect(result).toEqual(value); + } + }, + ); + + it('writes string with utf-8', () => { + const expected = '123qwe你好こんにちは./'; + writer.string(expected, 'utf-8'); + + const buffer = writer.getBuffer(); + const decoder = new TextDecoder('utf-8'); + expect(decoder.decode(buffer)).toBe(expected); + }); + + it('writes string with ascii', () => { + const expected = '123qwe./%$'; + writer.string(expected, 'ascii'); + + const buffer = writer.getBuffer(); + const decoder = new TextDecoder('ascii'); + expect(decoder.decode(buffer)).toBe(expected); + }); + + it('chains multiple write calls correctly', () => { + writer.byte(0x01).uint16(0x0302).byte(0x04); + expect(writer.getBuffer()).toStrictEqual(new Uint8Array([0x01, 0x02, 0x03, 0x04])); + }); + + it('dynamically resizes the buffer when needed', () => { + const data = new Uint8Array([1, 2, 3, 4, 5, 6, 7, 8, 9]); + writer.bytes(data); + + expect(writer.getBuffer()).toStrictEqual(data); + }); +});