上篇文章使用chatGPT翻译了db_tutorial 文章,文中使用的是c语言开发; 这篇文章使用chatGPT根据db_tutorial中的c源码,使用golang进行重写, 测试的ruby代码使用python进行重写;同理其他语言也适用。

注:这是利用已有的知识结构,借助chatGPT生成另一种表达形式(现实中这种转换很常见,比如同一个基础知识点,嚼碎了、揉烂了,底层相通,只是换了种表达方式、变个花样来讲,甚至还能借助认知差来盈利,也许精致的利己主义者能借此实现利益最大化吧),也是使用AGI工具提升编码效率的一次小小实践。在实践过程中,chatGPT生成的代码并不一定都能直接正常运行,需要手动调试(特别是涉及指针/unsafe的操作)。

整体实现代码:https://github.com/weedge/baby-db/tree/main/golang

主要的btree数据结构为leafNode 和 internalNode,叶子节点表数据存放在value中,id存放在key中,序列化和遍历操作需要额外偏移操作;这里仅实现简单的insert和select操作。

leafNode

internalNode

第一部分 - 简介和设置REPL

制作一个简单的 REPL(Read-Eval-Print Loop) golang 版本

package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
)

// InputBuffer holds the most recent line typed at the prompt.
// bufferLength mirrors the C version's field and is not maintained here;
// inputLength is the length of buffer after the newline is stripped.
type InputBuffer struct {
	buffer       string
	bufferLength int
	inputLength  int
}

// newInputBuffer returns a fresh, empty InputBuffer.
func newInputBuffer() *InputBuffer {
	return &InputBuffer{}
}

// printPrompt writes the REPL prompt (without a newline) to stdout.
func printPrompt() {
	const prompt = "db > "
	fmt.Print(prompt)
}

// readInput reads the next line from reader into inputBuffer.
//
// The line terminator is removed, accepting both "\n" and "\r\n" so that
// commands typed on CRLF terminals (".exit\r\n") still match exactly.
// The reader is created once by the caller and reused: building a new
// bufio.Reader on every call (as first generated) can read ahead past the
// current line and silently drop buffered input.
// Any read error, including EOF before a newline, prints a message and
// exits with status 1.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	input, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}

	input = strings.TrimSuffix(input, "\n")
	input = strings.TrimSuffix(input, "\r")
	inputBuffer.inputLength = len(input)
	inputBuffer.buffer = input
}

// closeInputBuffer releases the data held by inputBuffer.
//
// Go is garbage collected, so nothing needs freeing explicitly. Assigning
// nil to the parameter would only change this function's copy of the
// pointer and have no effect, so the fields are reset instead and the
// struct no longer references the last line read. A nil argument is ignored.
func closeInputBuffer(inputBuffer *InputBuffer) {
	if inputBuffer == nil {
		return
	}
	inputBuffer.buffer = ""
	inputBuffer.bufferLength = 0
	inputBuffer.inputLength = 0
}

// main runs the REPL: it prompts, reads a line, exits on ".exit", and
// reports every other non-empty line as an unrecognized command.
func main() {
	stdin := bufio.NewReader(os.Stdin)
	line := newInputBuffer()
	for {
		printPrompt()
		readInput(stdin, line)
		switch line.buffer {
		case "":
			// Nothing typed; prompt again.
		case ".exit":
			closeInputBuffer(line)
			os.Exit(0)
		default:
			fmt.Printf("Unrecognized command '%s'.\n", line.buffer)
		}
	}
}

运行结果:

go run golang/1.go
db > .tables
Unrecognized command '.tables'.
db > .exit

第二部分 - 世界上最简单的SQL编译器和虚拟机

golang 版本

package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
)

// InputBuffer holds the most recent line typed at the prompt.
// bufferLength mirrors the C version's field and is never updated in this
// program; readInput sets buffer and inputLength.
type InputBuffer struct {
	buffer       string
	bufferLength int
	inputLength  int
}

// MetaCommandResult is the outcome of handling a "."-prefixed command.
type MetaCommandResult int

const (
	META_COMMAND_SUCCESS MetaCommandResult = iota
	META_COMMAND_UNRECOGNIZED_COMMAND
)

// PrepareResult is the outcome of turning an input line into a Statement.
type PrepareResult int

const (
	PREPARE_SUCCESS PrepareResult = iota
	PREPARE_UNRECOGNIZED_STATEMENT
)

// StatementType identifies which statement was recognized.
// STATEMENT_INSERT is the zero value.
type StatementType int

const (
	STATEMENT_INSERT StatementType = iota
	STATEMENT_SELECT
)

// Statement is the "compiled" form of an input line; at this stage it only
// records the statement kind.
type Statement struct {
	Type StatementType
}

// newInputBuffer returns a fresh, empty InputBuffer.
func newInputBuffer() *InputBuffer {
	return &InputBuffer{}
}

// printPrompt writes the REPL prompt (without a newline) to stdout.
func printPrompt() {
	const prompt = "db > "
	fmt.Print(prompt)
}

// readInput reads the next line from reader into inputBuffer.
//
// The line terminator is removed, accepting both "\n" and "\r\n" so that
// commands typed on CRLF terminals (".exit\r\n") still match exactly.
// The reader is created once by the caller and reused: building a new
// bufio.Reader on every call (as first generated) can read ahead past the
// current line and silently drop buffered input.
// Any read error, including EOF before a newline, prints a message and
// exits with status 1.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	input, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}

	input = strings.TrimSuffix(input, "\n")
	input = strings.TrimSuffix(input, "\r")
	inputBuffer.inputLength = len(input)
	inputBuffer.buffer = input
}

// doMetaCommand handles "."-prefixed commands. ".exit" terminates the
// process with status 0; every other meta command is unrecognized.
func doMetaCommand(inputBuffer *InputBuffer) MetaCommandResult {
	if inputBuffer.buffer != ".exit" {
		return META_COMMAND_UNRECOGNIZED_COMMAND
	}
	os.Exit(0)
	return META_COMMAND_SUCCESS // not reached: os.Exit does not return
}

// prepareStatement classifies the input line. A line starting with "insert"
// is an insert, the exact line "select" is a select, and anything else is
// reported as unrecognized (statement is left untouched in that case).
func prepareStatement(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	line := inputBuffer.buffer
	switch {
	case strings.HasPrefix(line, "insert"):
		statement.Type = STATEMENT_INSERT
	case line == "select":
		statement.Type = STATEMENT_SELECT
	default:
		return PREPARE_UNRECOGNIZED_STATEMENT
	}
	return PREPARE_SUCCESS
}

// executeStatement is a stub virtual machine: it only announces which kind
// of statement it would run.
func executeStatement(statement *Statement) {
	if statement.Type == STATEMENT_INSERT {
		fmt.Println("This is where we would do an insert.")
	} else if statement.Type == STATEMENT_SELECT {
		fmt.Println("This is where we would do a select.")
	}
}

// main runs the REPL loop: lines starting with '.' are meta commands,
// everything else is compiled by prepareStatement and handed to the stub
// executeStatement.
func main() {
	inputBuffer := newInputBuffer()
	reader := bufio.NewReader(os.Stdin)
	for {
		printPrompt()
		readInput(reader, inputBuffer)
		line := inputBuffer.buffer
		if line == "" {
			continue
		}

		if strings.HasPrefix(line, ".") {
			// doMetaCommand only returns for commands other than ".exit".
			if doMetaCommand(inputBuffer) == META_COMMAND_UNRECOGNIZED_COMMAND {
				fmt.Printf("Unrecognized command '%s'\n", line)
			}
			continue
		}

		var statement Statement
		if prepareStatement(inputBuffer, &statement) == PREPARE_UNRECOGNIZED_STATEMENT {
			fmt.Printf("Unrecognized keyword at start of '%s'.\n", line)
			continue
		}

		executeStatement(&statement)
		fmt.Println("Executed.")
	}
}

运行结果:

go run golang/2.go
db > insert foo bar
This is where we would do an insert.
Executed.
db >
db > delete foo
Unrecognized keyword at start of 'delete foo'.
db > select
This is where we would do a select.
Executed.
db > .tables
Unrecognized command '.tables'
db > .exit

第三部分 - 内存中、追加方式、单表数据库

golang 版本

package main

import (
	"bufio"
	"fmt"
	"os"
	"strings"
	"unsafe"
)

// InputBuffer holds the most recent line typed at the prompt.
// bufferLength mirrors the C version's field and is never updated in this
// program; readInput sets buffer and inputLength.
type InputBuffer struct {
	buffer       string
	bufferLength int
	inputLength  int
}

// MetaCommandResult is the outcome of handling a "."-prefixed command.
type MetaCommandResult int

const (
	META_COMMAND_SUCCESS MetaCommandResult = iota
	META_COMMAND_UNRECOGNIZED_COMMAND
)

// PrepareResult is the outcome of parsing an input line into a Statement.
type PrepareResult int

const (
	PREPARE_SUCCESS PrepareResult = iota
	PREPARE_SYNTAX_ERROR
	PREPARE_UNRECOGNIZED_STATEMENT
)

// StatementType identifies the parsed statement; STATEMENT_INSERT is the
// zero value.
type StatementType int

const (
	STATEMENT_INSERT StatementType = iota
	STATEMENT_SELECT
)

// Maximum byte lengths of the two text columns.
const (
	COLUMN_USERNAME_SIZE = 32
	COLUMN_EMAIL_SIZE    = 255
)

// Row is one record of the single hard-coded table. The text columns are
// fixed-size and NUL-padded; a value that fills the whole array has no
// terminating NUL (printRow trims padding rather than relying on one).
type Row struct {
	id       uint32
	username [COLUMN_USERNAME_SIZE]byte
	email    [COLUMN_EMAIL_SIZE]byte
}

// Statement is a parsed command.
type Statement struct {
	Type        StatementType
	rowToInsert Row // only used by insert statements
}

// Serialized row layout: id, then username, then email, packed with no
// padding (4 + 32 + 255 = 291 bytes). These constants are typed int because
// of the int(...) conversions, which is why callers convert them to uint32.
const (
	ID_SIZE         = int(unsafe.Sizeof(uint32(0)))
	USERNAME_SIZE   = int(unsafe.Sizeof([COLUMN_USERNAME_SIZE]byte{}))
	EMAIL_SIZE      = int(unsafe.Sizeof([COLUMN_EMAIL_SIZE]byte{}))
	ID_OFFSET       = 0
	USERNAME_OFFSET = ID_OFFSET + ID_SIZE
	EMAIL_OFFSET    = USERNAME_OFFSET + USERNAME_SIZE
	ROW_SIZE        = ID_SIZE + USERNAME_SIZE + EMAIL_SIZE
)

// Storage limits: rows never straddle pages, so each 4096-byte page holds
// 4096/291 = 14 rows and the table caps at 1400 rows.
const (
	PAGE_SIZE       = 4096
	TABLE_MAX_PAGES = 100
	ROWS_PER_PAGE   = PAGE_SIZE / ROW_SIZE
	TABLE_MAX_ROWS  = ROWS_PER_PAGE * TABLE_MAX_PAGES
)

// Table is the in-memory, append-only store. pages entries are nil until
// rowSlot first touches them.
type Table struct {
	numRows uint32
	pages   [TABLE_MAX_PAGES][]byte
}

// printRow prints a row as "(id, username, email)", dropping the NUL padding
// that fills the unused tail of the fixed-size text columns.
func printRow(row *Row) {
	username := strings.TrimRight(string(row.username[:]), "\x00")
	email := strings.TrimRight(string(row.email[:]), "\x00")
	fmt.Printf("(%d, %s, %s)\n", row.id, username, email)
}

// serializeRow packs source into destination using the fixed row layout:
// the id's 4 bytes in the machine's native byte order, then the username and
// email arrays. destination must be at least ROW_SIZE bytes long.
func serializeRow(source *Row, destination []byte) {
	copy(destination[ID_OFFSET:], (*(*[ID_SIZE]byte)(unsafe.Pointer(&source.id)))[:])
	copy(destination[USERNAME_OFFSET:], source.username[:])
	copy(destination[EMAIL_OFFSET:], source.email[:])
}

// deserializeRow unpacks the ROW_SIZE bytes at the start of source into
// destination, reversing serializeRow.
//
// The id bytes are copied into destination.id through a *[ID_SIZE]byte view
// of that (properly aligned) field. Reading through *(*uint32)(&source[i])
// would convert a *byte into a larger *uint32, which the unsafe.Pointer
// rules do not allow. Row slots start at multiples of the odd ROW_SIZE, so
// that read is usually misaligned. It also skips the slice bounds check on
// the 4 bytes being read.
func deserializeRow(source []byte, destination *Row) {
	copy((*[ID_SIZE]byte)(unsafe.Pointer(&destination.id))[:], source[ID_OFFSET:ID_OFFSET+ID_SIZE])
	copy(destination.username[:], source[USERNAME_OFFSET:USERNAME_OFFSET+USERNAME_SIZE])
	copy(destination.email[:], source[EMAIL_OFFSET:EMAIL_OFFSET+EMAIL_SIZE])
}

// rowSlot returns the ROW_SIZE-byte window of table memory that holds row
// rowNum, allocating the containing page the first time it is touched.
// rowNum must stay below TABLE_MAX_ROWS; a larger value indexes past the
// pages array and panics.
func rowSlot(table *Table, rowNum uint32) []byte {
	perPage := uint32(ROWS_PER_PAGE)
	pageNum, slot := rowNum/perPage, rowNum%perPage
	if table.pages[pageNum] == nil {
		table.pages[pageNum] = make([]byte, PAGE_SIZE)
	}
	start := slot * uint32(ROW_SIZE)
	return table.pages[pageNum][start : start+uint32(ROW_SIZE)]
}

// newTable returns an empty in-memory table with no pages allocated.
func newTable() *Table {
	return &Table{numRows: 0}
}

// freeTable drops every page so the garbage collector can reclaim it.
func freeTable(table *Table) {
	for i := range table.pages {
		table.pages[i] = nil
	}
}

// doMetaCommand handles "."-prefixed commands. ".exit" releases the input
// buffer and the table, then ends the process with status 0. Anything else
// is reported back as unrecognized.
func doMetaCommand(inputBuffer *InputBuffer, table *Table) MetaCommandResult {
	if inputBuffer.buffer != ".exit" {
		return META_COMMAND_UNRECOGNIZED_COMMAND
	}
	closeInputBuffer(inputBuffer)
	freeTable(table)
	os.Exit(0)
	return META_COMMAND_SUCCESS // not reached: os.Exit does not return
}
// BytesToString views b as a string without copying. The caller must not
// modify b while the returned string is in use.
func BytesToString(b []byte) string {
	return unsafe.String(unsafe.SliceData(b), len(b))
}

// StringToBytes views s as a byte slice without copying. The slice aliases
// the string's immutable storage and must never be written to.
func StringToBytes(s string) []byte {
	return unsafe.Slice(unsafe.StringData(s), len(s))
}

// prepareStatement parses the input line into statement.
//
// "insert <id> <username> <email>" fills statement.rowToInsert. fmt's
// scanners cannot fill a fixed-size [N]byte array directly, so the two text
// fields are scanned into strings first and then copied in (truncated to the
// column size). The exact line "select" selects. Anything else is
// unrecognized.
func prepareStatement(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	line := inputBuffer.buffer
	switch {
	case strings.HasPrefix(line, "insert"):
		statement.Type = STATEMENT_INSERT
		row := &statement.rowToInsert
		var username, email string
		if assigned, _ := fmt.Sscanf(line, "insert %d %s %s", &row.id, &username, &email); assigned < 3 {
			return PREPARE_SYNTAX_ERROR
		}
		copy(row.username[:], username)
		copy(row.email[:], email)
		return PREPARE_SUCCESS
	case line == "select":
		statement.Type = STATEMENT_SELECT
		return PREPARE_SUCCESS
	default:
		return PREPARE_UNRECOGNIZED_STATEMENT
	}
}

// executeInsert appends the statement's row after the last stored row. Once
// TABLE_MAX_ROWS rows are stored it returns an error and leaves the table
// untouched.
func executeInsert(statement *Statement, table *Table) error {
	if table.numRows < uint32(TABLE_MAX_ROWS) {
		serializeRow(&statement.rowToInsert, rowSlot(table, table.numRows))
		table.numRows++
		return nil
	}
	return fmt.Errorf("Error: Table full.")
}

// executeSelect prints every stored row in insertion order.
func executeSelect(table *Table) {
	var row Row
	for n := uint32(0); n < table.numRows; n++ {
		deserializeRow(rowSlot(table, n), &row)
		printRow(&row)
	}
}

// executeStatement runs a prepared statement against table. Only inserts
// can fail; any other statement type yields nil.
func executeStatement(statement *Statement, table *Table) error {
	if statement.Type == STATEMENT_INSERT {
		return executeInsert(statement, table)
	}
	if statement.Type == STATEMENT_SELECT {
		executeSelect(table)
	}
	return nil
}

// closeInputBuffer releases the data held by inputBuffer.
//
// Go is garbage collected, so nothing needs freeing explicitly. Assigning
// nil to the parameter would only change this function's copy of the
// pointer and have no effect, so the fields are reset instead. A nil
// argument is ignored.
func closeInputBuffer(inputBuffer *InputBuffer) {
	if inputBuffer == nil {
		return
	}
	inputBuffer.buffer = ""
	inputBuffer.bufferLength = 0
	inputBuffer.inputLength = 0
}

// printPrompt writes the REPL prompt (without a newline) to stdout.
func printPrompt() {
	fmt.Print("db > ")
}

// readInput reads the next line from reader into inputBuffer.
//
// The line terminator is removed, accepting both "\n" and "\r\n" so that
// commands typed on CRLF terminals (".exit\r\n") still match exactly.
// The reader is created once by the caller and reused: building a new
// bufio.Reader on every call (as first generated) can read ahead past the
// current line and silently drop buffered input.
// Any read error, including EOF before a newline, prints a message and
// exits with status 1.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	input, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}

	input = strings.TrimSuffix(input, "\n")
	input = strings.TrimSuffix(input, "\r")
	inputBuffer.inputLength = len(input)
	inputBuffer.buffer = input
}

// main runs the REPL against a fresh in-memory table. Lines starting with
// '.' are meta commands; other lines are parsed and executed as statements.
func main() {
	table := newTable()
	inputBuffer := new(InputBuffer)
	reader := bufio.NewReader(os.Stdin)
	for {
		printPrompt()
		readInput(reader, inputBuffer)
		line := inputBuffer.buffer
		if line == "" {
			continue
		}

		if line[0] == '.' {
			// doMetaCommand only returns for commands other than ".exit".
			if doMetaCommand(inputBuffer, table) == META_COMMAND_UNRECOGNIZED_COMMAND {
				fmt.Printf("Unrecognized command '%s'\n", line)
			}
			continue
		}

		var statement Statement
		result := prepareStatement(inputBuffer, &statement)
		if result == PREPARE_SYNTAX_ERROR {
			fmt.Println("Syntax error. Could not parse statement.")
			continue
		}
		if result == PREPARE_UNRECOGNIZED_STATEMENT {
			fmt.Printf("Unrecognized keyword at start of '%s'\n", line)
			continue
		}

		if err := executeStatement(&statement, table); err != nil {
			fmt.Println(err.Error())
			continue
		}
		fmt.Println("Executed.")
	}
}

运行结果:

go run golang/3.go
db > insert 1 cstack foo@bar.com
Executed.
db > insert 2 bob bob@example.com
Executed.
db > select
(1, cstack, foo@bar.com)
(2, bob, bob@example.com)
Executed.
db > insert foo bar 1
Syntax error. Could not parse statement.
db > .exit

第四部分 - 我们的第一个测试(Bug)

使用golang重写:

package main

import (
	"bufio"
	"fmt"
	"os"
	"strconv"
	"strings"
	"unsafe"
)

// Row layout and storage limits. username and email carry one extra byte
// each (mirroring the C version's NUL terminator), so a serialized row is
// ID_SIZE + USERNAME_SIZE + EMAIL_SIZE = 4 + 33 + 256 = 293 bytes. Rows never
// straddle pages, so a 4096-byte page holds ROWS_PER_PAGE = 13 rows and the
// table caps at TABLE_MAX_ROWS = 1300 rows.
const (
	COLUMN_USERNAME_SIZE = 32
	COLUMN_EMAIL_SIZE    = 255
	ID_SIZE              = 4
	USERNAME_SIZE        = COLUMN_USERNAME_SIZE + 1
	EMAIL_SIZE           = COLUMN_EMAIL_SIZE + 1
	ID_OFFSET            = 0
	USERNAME_OFFSET      = ID_OFFSET + ID_SIZE
	EMAIL_OFFSET         = USERNAME_OFFSET + USERNAME_SIZE
	ROW_SIZE             = ID_SIZE + USERNAME_SIZE + EMAIL_SIZE
	PAGE_SIZE            = 4096
	TABLE_MAX_PAGES      = 100
	ROWS_PER_PAGE        = PAGE_SIZE / ROW_SIZE
	TABLE_MAX_ROWS       = ROWS_PER_PAGE * TABLE_MAX_PAGES
)

// Row is one record of the single hard-coded table. The text columns are
// NUL-padded; prepareInsert rejects values longer than COLUMN_*_SIZE bytes,
// so the last byte of each array stays zero.
type Row struct {
	id       uint32
	username [COLUMN_USERNAME_SIZE + 1]byte
	email    [COLUMN_EMAIL_SIZE + 1]byte
}

// Table is the in-memory, append-only store: numRows rows packed into
// PAGE_SIZE pages that rowSlot allocates on first use.
type Table struct {
	numRows uint32
	pages   [TABLE_MAX_PAGES][]byte
}

// InputBuffer holds the most recent input line (whitespace-trimmed by
// readInput). bufferLength mirrors the C version's field and is not
// updated in this program.
type InputBuffer struct {
	buffer       []byte
	bufferLength int
	inputLength  int
}

// StatementType identifies the parsed statement. STATEMENT_INSERT is the
// zero value, so a Statement whose type was never set reads as an insert.
type StatementType int

const (
	STATEMENT_INSERT StatementType = iota
	STATEMENT_SELECT
)

// Statement is a parsed command; rowToInsert is only meaningful for inserts.
type Statement struct {
	stmtType    StatementType
	rowToInsert Row
}

// ExecuteResult is the outcome of running a Statement.
type ExecuteResult int

const (
	EXECUTE_SUCCESS ExecuteResult = iota
	EXECUTE_TABLE_FULL
)

// MetaCommandResult is the outcome of a "."-prefixed command.
type MetaCommandResult int

const (
	META_COMMAND_SUCCESS MetaCommandResult = iota
	META_COMMAND_UNRECOGNIZED_COMMAND
)

// PrepareResult is the outcome of parsing an input line into a Statement.
type PrepareResult int

const (
	PREPARE_SUCCESS PrepareResult = iota
	PREPARE_NEGATIVE_ID
	PREPARE_STRING_TOO_LONG
	PREPARE_SYNTAX_ERROR
	PREPARE_UNRECOGNIZED_STATEMENT
)

// newInputBuffer returns an InputBuffer whose byte buffer is empty but
// non-nil.
func newInputBuffer() *InputBuffer {
	ib := new(InputBuffer)
	ib.buffer = []byte{}
	return ib
}

// newTable returns an empty table. Every page slot starts out nil and is
// allocated lazily by rowSlot.
func newTable() *Table {
	return new(Table)
}

// printRow prints a row as "(id, username, email)", dropping the NUL padding
// of the fixed-size text columns.
func printRow(row *Row) {
	fmt.Printf("(%d, %s, %s)\n", row.id, strings.TrimRight(string(row.username[:]), "\x00"), strings.TrimRight(string(row.email[:]), "\x00"))
}

// serializeRow packs source into destination using the fixed row layout:
// the id's 4 bytes in the machine's native byte order, then the username and
// email arrays. destination must be at least ROW_SIZE bytes long.
func serializeRow(source *Row, destination []byte) {
	copy(destination[ID_OFFSET:], (*(*[ID_SIZE]byte)(unsafe.Pointer(&source.id)))[:])
	copy(destination[USERNAME_OFFSET:], source.username[:])
	copy(destination[EMAIL_OFFSET:], source.email[:])
}

// deserializeRow unpacks the ROW_SIZE bytes at the start of source into
// destination, reversing serializeRow.
//
// The id bytes are copied into destination.id through a *[ID_SIZE]byte view
// of that (properly aligned) field. Reading through *(*uint32)(&source[i])
// would convert a *byte into a larger *uint32, which the unsafe.Pointer
// rules do not allow. Row slots start at multiples of the odd ROW_SIZE (293),
// so that read is usually misaligned. It also skips the slice bounds check
// on the 4 bytes being read.
func deserializeRow(source []byte, destination *Row) {
	copy((*[ID_SIZE]byte)(unsafe.Pointer(&destination.id))[:], source[ID_OFFSET:ID_OFFSET+ID_SIZE])
	copy(destination.username[:], source[USERNAME_OFFSET:USERNAME_OFFSET+USERNAME_SIZE])
	copy(destination.email[:], source[EMAIL_OFFSET:EMAIL_OFFSET+EMAIL_SIZE])
}

// rowSlot returns the ROW_SIZE bytes of table memory that hold row rowNum,
// allocating the containing page on first use. rowNum must be below
// TABLE_MAX_ROWS; larger values index past the pages array and panic.
func rowSlot(table *Table, rowNum uint32) []byte {
	pageNum, slot := rowNum/ROWS_PER_PAGE, rowNum%ROWS_PER_PAGE
	page := table.pages[pageNum]
	if page == nil {
		page = make([]byte, PAGE_SIZE)
		table.pages[pageNum] = page
	}
	start := slot * ROW_SIZE
	end := start + ROW_SIZE
	return page[start:end]
}

// printPrompt writes the REPL prompt (without a newline) to stdout.
func printPrompt() {
	const prompt = "db > "
	fmt.Print(prompt)
}

// readInput reads the next line from reader into inputBuffer, with all
// leading and trailing whitespace (including the newline) removed. The
// reader is shared across calls because a new bufio.Reader per call could
// read ahead and lose input. Any read error, including EOF, prints a message
// and exits with status 1.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	line, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}
	inputBuffer.buffer = []byte(strings.TrimSpace(line))
	inputBuffer.inputLength = len(inputBuffer.buffer)
}

// closeInputBuffer is kept for parity with the C version; the garbage
// collector reclaims the buffer, so there is nothing to do.
func closeInputBuffer(inputBuffer *InputBuffer) {}

// doMetaCommand handles "."-prefixed commands: ".exit" ends the process with
// status 0, and every other command is reported as unrecognized.
func doMetaCommand(inputBuffer *InputBuffer, table *Table) MetaCommandResult {
	if string(inputBuffer.buffer) != ".exit" {
		return META_COMMAND_UNRECOGNIZED_COMMAND
	}
	closeInputBuffer(inputBuffer)
	os.Exit(0)
	return META_COMMAND_SUCCESS // not reached: os.Exit does not return
}

// prepareInsert parses "insert <id> <username> <email>" into statement.
//
// Results:
//   - PREPARE_SYNTAX_ERROR if there are not exactly four fields, if the id is
//     not a base-10 integer (e.g. "insert foo bar 1"), or if the id does not
//     fit in the uint32 id column;
//   - PREPARE_NEGATIVE_ID if the id is below zero;
//   - PREPARE_STRING_TOO_LONG if username or email exceeds its column size
//     in bytes;
//   - PREPARE_SUCCESS otherwise, with statement.rowToInsert filled in.
func prepareInsert(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	statement.stmtType = STATEMENT_INSERT

	tokens := strings.Fields(string(inputBuffer.buffer))
	if len(tokens) != 4 {
		return PREPARE_SYNTAX_ERROR
	}

	// Parse at an explicit 64-bit size so the whole uint32 range is accepted
	// even where int is 32 bits. Text that is not a number is a syntax error,
	// not a "negative" id.
	id, err := strconv.ParseInt(tokens[1], 10, 64)
	if err != nil {
		return PREPARE_SYNTAX_ERROR
	}
	if id < 0 {
		return PREPARE_NEGATIVE_ID
	}
	if id > 1<<32-1 {
		// The uint32 conversion below would silently truncate it.
		return PREPARE_SYNTAX_ERROR
	}

	if len(tokens[2]) > COLUMN_USERNAME_SIZE || len(tokens[3]) > COLUMN_EMAIL_SIZE {
		return PREPARE_STRING_TOO_LONG
	}

	statement.rowToInsert.id = uint32(id)
	copy(statement.rowToInsert.username[:], tokens[2])
	copy(statement.rowToInsert.email[:], tokens[3])

	return PREPARE_SUCCESS
}

// prepareStatement dispatches on the first whitespace-separated word of the
// input: "insert" is parsed by prepareInsert and "select" needs no
// arguments. Anything else, including a blank line, is unrecognized.
func prepareStatement(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	tokens := strings.Fields(string(inputBuffer.buffer))

	if len(tokens) == 0 {
		// Reporting success here would leave statement.stmtType at its zero
		// value (STATEMENT_INSERT) and make the caller insert an all-zero row.
		return PREPARE_UNRECOGNIZED_STATEMENT
	}

	switch tokens[0] {
	case "insert":
		return prepareInsert(inputBuffer, statement)
	case "select":
		statement.stmtType = STATEMENT_SELECT
		return PREPARE_SUCCESS
	default:
		return PREPARE_UNRECOGNIZED_STATEMENT
	}
}

// executeInsert appends the statement's row after the last stored row. It
// reports EXECUTE_TABLE_FULL, leaving the table untouched, once
// TABLE_MAX_ROWS rows are stored.
func executeInsert(statement *Statement, table *Table) ExecuteResult {
	if table.numRows < TABLE_MAX_ROWS {
		serializeRow(&statement.rowToInsert, rowSlot(table, table.numRows))
		table.numRows++
		return EXECUTE_SUCCESS
	}
	return EXECUTE_TABLE_FULL
}

// executeSelect prints every stored row in insertion order. A select takes
// no arguments, so statement is unused.
func executeSelect(statement *Statement, table *Table) ExecuteResult {
	var row Row
	for n := uint32(0); n < table.numRows; n++ {
		deserializeRow(rowSlot(table, n), &row)
		printRow(&row)
	}
	return EXECUTE_SUCCESS
}

// executeStatement dispatches on the statement type; unknown types are
// successful no-ops.
func executeStatement(statement *Statement, table *Table) ExecuteResult {
	if statement.stmtType == STATEMENT_INSERT {
		return executeInsert(statement, table)
	}
	if statement.stmtType == STATEMENT_SELECT {
		return executeSelect(statement, table)
	}
	return EXECUTE_SUCCESS
}

// main runs the REPL on an in-memory table. Blank lines are ignored, lines
// starting with '.' are meta commands, and everything else is parsed and
// executed as a statement.
func main() {
	table := newTable()
	inputBuffer := newInputBuffer()
	reader := bufio.NewReader(os.Stdin)

	for {
		printPrompt()
		readInput(reader, inputBuffer)

		// readInput trims whitespace, so a blank line leaves an empty buffer.
		// Indexing buffer[0] below would panic on it, so just prompt again.
		if len(inputBuffer.buffer) == 0 {
			continue
		}

		if inputBuffer.buffer[0] == '.' {
			switch doMetaCommand(inputBuffer, table) {
			case META_COMMAND_SUCCESS:
				continue
			case META_COMMAND_UNRECOGNIZED_COMMAND:
				fmt.Printf("Unrecognized command '%s'\n", inputBuffer.buffer)
				continue
			}
		}

		var statement Statement
		switch prepareStatement(inputBuffer, &statement) {
		case PREPARE_SUCCESS:
			// Parsed; fall out of the switch and execute.
		case PREPARE_NEGATIVE_ID:
			fmt.Println("ID must be positive.")
			continue
		case PREPARE_STRING_TOO_LONG:
			fmt.Println("String is too long.")
			continue
		case PREPARE_SYNTAX_ERROR:
			fmt.Println("Syntax error. Could not parse statement.")
			continue
		case PREPARE_UNRECOGNIZED_STATEMENT:
			fmt.Printf("Unrecognized keyword at start of '%s'.\n", inputBuffer.buffer)
			continue
		}

		switch executeStatement(&statement, table) {
		case EXECUTE_SUCCESS:
			fmt.Println("Executed.")
		case EXECUTE_TABLE_FULL:
			fmt.Println("Error: Table full.")
		}
	}
}

测试辅助脚本 util.py:通过管道(pipe)把命令逐行写入被测程序的标准输入,并返回其标准输出的各行(util.py 需要 `import subprocess`):

def run_script(commands, bin_file="./db", db_file=""):
    """Run ``bin_file db_file``, feed it ``commands`` one per line, return its stdout lines.

    Args:
        commands: iterable of command strings; each is sent followed by "\\n".
        bin_file: path of the executable under test.
        db_file: single argument passed to the executable (the database path
            for programs that persist to disk; "" when unused).

    Returns:
        The process's standard output split into lines (without newlines).
    """
    import subprocess  # util.py as shown has no module-level import of subprocess

    script = "".join(command + "\n" for command in commands)
    with subprocess.Popen(
        [bin_file, db_file],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        text=True,
    ) as process:
        # communicate() writes stdin and drains stdout concurrently, so a long
        # script whose output fills the stdout pipe cannot deadlock the way
        # "write everything, then read" can.
        raw_output, _ = process.communicate(script)
    return raw_output.splitlines()

使用python重写测试:

import subprocess
import sys
from util import run_script

def test_inserts_and_retrieves_row():
    """Insert one row, select it back, and check the exact REPL transcript."""
    commands = ["insert 1 user1 person1@example.com", "select", ".exit"]
    expected_result = [
        "db > Executed.",
        "db > (1, user1, person1@example.com)",
        "Executed.",
        "db > ",
    ]
    result = run_script(commands)
    assert result == expected_result, "Test failed"
    print(f"{test_inserts_and_retrieves_row.__name__} passed")

# Run the test
test_inserts_and_retrieves_row()

# Table capacity depends on the row layout. With the Part 4 layout
# (id 4 + username 33 + email 256 = 293 bytes) a 4096-byte page holds 13 rows,
# so 100 pages hold 1300 rows. The often-quoted 1400 rows / 14 per page only
# applied to Part 3's 291-byte rows. Inserting 1401 rows overflows either
# layout, so the last insert must report that the table is full.
def test_prints_error_message_when_table_is_full():
    """Insert past capacity and expect the final insert to print 'Error: Table full.'."""
    script = [f"insert {i} user{i} person{i}@example.com" for i in range(1, 1402)]
    script.append(".exit")
    result = run_script(script)
    expected_result = "db > Error: Table full."
    # result[-1] is the prompt printed before ".exit"; result[-2] is the last insert's output.
    assert result[-2] == expected_result, "Test failed"
    print(f"{sys._getframe().f_code.co_name} passed")

# Run the test
test_prints_error_message_when_table_is_full()


def test_allows_inserting_strings_that_are_maximum_length():
    """A 32-byte username and a 255-byte email are accepted and read back intact."""
    long_username = "a" * 32
    long_email = "a" * 255
    result = run_script([f"insert 1 {long_username} {long_email}", "select", ".exit"])
    expected_result = [
        "db > Executed.",
        f"db > (1, {long_username}, {long_email})",
        "Executed.",
        "db > ",
    ]
    assert result == expected_result, "Test failed"
    print(f"{test_allows_inserting_strings_that_are_maximum_length.__name__} passed")

# Run the test
test_allows_inserting_strings_that_are_maximum_length()


def test_prints_error_message_if_strings_are_too_long():
    """A 33-byte username / 256-byte email is rejected and nothing is stored."""
    too_long_username = "a" * 33
    too_long_email = "a" * 256
    result = run_script([f"insert 1 {too_long_username} {too_long_email}", "select", ".exit"])
    expected_result = ["db > String is too long.", "db > Executed.", "db > "]
    assert result == expected_result, "Test failed"
    print(f"{test_prints_error_message_if_strings_are_too_long.__name__} passed")

# Run the test
test_prints_error_message_if_strings_are_too_long()


def test_negative_id_error_message():
    """A negative id is rejected with 'ID must be positive.' and nothing is stored."""
    expected_output = ["db > ID must be positive.", "db > Executed.", "db > "]
    result = run_script(["insert -1 cstack foo@bar.com", "select", ".exit"])
    assert result == expected_output, f"Expected: {expected_output}, but got: {result}"
    print(f"{test_negative_id_error_message.__name__} passed")

test_negative_id_error_message()

print("all tests passed.")

第五部分 - 持久化到磁盘

测试用例:

import sys,os
from util import run_script

def test_keeps_data_after_closing_connection(db_file):
    """A row inserted in one session must still be selectable after reopening db_file."""
    first_session = run_script(
        ["insert 1 user1 person1@example.com", ".exit"],
        db_file=db_file,
    )
    assert first_session == ["db > Executed.", "db > "]

    second_session = run_script(["select", ".exit"], db_file=db_file)
    assert second_session == [
        "db > (1, user1, person1@example.com)",
        "Executed.",
        "db > ",
    ]
    print(f"{test_keeps_data_after_closing_connection.__name__} passed")

# The database file path is a required argument. Exit non-zero when it is
# missing so a test runner sees a failure rather than a silent success.
if len(sys.argv) < 2:
    print("need db file path")
    sys.exit(1)
db_file = sys.argv[1]
# Start from an empty database so the result does not depend on earlier runs.
if os.path.exists(db_file):
    os.remove(db_file)

test_keeps_data_after_closing_connection(db_file)

print("all tests passed.")

持久化 golang版本:

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"
	"unsafe"
)

// Row layout and storage limits. username and email carry one extra byte
// each (mirroring the C version's NUL terminator), so a serialized row is
// 4 + 33 + 256 = 293 bytes. Rows never straddle pages: a 4096-byte page
// holds ROWS_PER_PAGE = 13 rows plus unused padding, and the table caps at
// TABLE_MAX_ROWS = 1300 rows.
const (
	COLUMN_USERNAME_SIZE = 32
	COLUMN_EMAIL_SIZE    = 255
	ID_SIZE              = 4
	USERNAME_SIZE        = COLUMN_USERNAME_SIZE + 1
	EMAIL_SIZE           = COLUMN_EMAIL_SIZE + 1
	ID_OFFSET            = 0
	USERNAME_OFFSET      = ID_OFFSET + ID_SIZE
	EMAIL_OFFSET         = USERNAME_OFFSET + USERNAME_SIZE
	ROW_SIZE             = ID_SIZE + USERNAME_SIZE + EMAIL_SIZE
	PAGE_SIZE            = 4096
	TABLE_MAX_PAGES      = 100
	ROWS_PER_PAGE        = PAGE_SIZE / ROW_SIZE
	TABLE_MAX_ROWS       = ROWS_PER_PAGE * TABLE_MAX_PAGES
)

// InputBuffer holds the most recent input line (newline stripped by
// readInput). bufferLength mirrors the C version's field and is not updated
// in this program.
type InputBuffer struct {
	buffer       string
	bufferLength int
	inputLength  int
}

// MetaCommandResult is the outcome of a "."-prefixed command.
type MetaCommandResult int

const (
	META_COMMAND_SUCCESS MetaCommandResult = iota
	META_COMMAND_UNRECOGNIZED_COMMAND
)

// PrepareResult is the outcome of parsing an input line into a Statement.
type PrepareResult int

const (
	PREPARE_SUCCESS PrepareResult = iota
	PREPARE_NEGATIVE_ID
	PREPARE_STRING_TOO_LONG
	PREPARE_SYNTAX_ERROR
	PREPARE_UNRECOGNIZED_STATEMENT
)

// StatementType identifies the parsed statement; STATEMENT_INSERT is the
// zero value.
type StatementType int

const (
	STATEMENT_INSERT StatementType = iota
	STATEMENT_SELECT
)

// Row is one record of the single hard-coded table; the text columns are
// NUL-padded.
type Row struct {
	id       uint32
	username [COLUMN_USERNAME_SIZE + 1]byte
	email    [COLUMN_EMAIL_SIZE + 1]byte
}

// Statement is a parsed command; rowToInsert is only meaningful for inserts.
type Statement struct {
	typ         StatementType
	rowToInsert Row
}

// Pager caches pages of the database file. fileLength is the file size in
// bytes as measured by pagerOpen. A nil entry in pages means getPage has not
// loaded that page yet.
type Pager struct {
	fileDescriptor *os.File
	fileLength     uint32
	pages          [TABLE_MAX_PAGES][]byte
}

// Table is the single table: numRows rows stored through pager.
type Table struct {
	numRows uint32
	pager   *Pager
}

// ExecuteResult is the outcome of running a Statement.
type ExecuteResult int

const (
	EXECUTE_SUCCESS ExecuteResult = iota
	EXECUTE_TABLE_FULL
)

// newInputBuffer returns an InputBuffer that holds no input yet.
func newInputBuffer() *InputBuffer {
	return &InputBuffer{}
}

// printRow prints a row as "(id, username, email)", dropping the NUL padding
// of the fixed-size text columns.
func printRow(row *Row) {
	fmt.Printf("(%d, %s, %s)\n", row.id, strings.TrimRight(string(row.username[:]), "\x00"), strings.TrimRight(string(row.email[:]), "\x00"))
}

// serializeRow packs source into destination using the on-disk row layout:
// the id's 4 bytes in the machine's native byte order, then the username and
// email arrays. destination must be at least ROW_SIZE bytes long.
func serializeRow(source *Row, destination []byte) {
	copy(destination[ID_OFFSET:], (*(*[ID_SIZE]byte)(unsafe.Pointer(&source.id)))[:])
	copy(destination[USERNAME_OFFSET:], source.username[:])
	copy(destination[EMAIL_OFFSET:], source.email[:])
}

// deserializeRow unpacks the ROW_SIZE bytes at the start of source into
// destination, reversing serializeRow.
//
// The id bytes are copied into destination.id through a *[ID_SIZE]byte view
// of that (properly aligned) field. Reading through *(*uint32)(&source[i])
// would convert a *byte into a larger *uint32, which the unsafe.Pointer
// rules do not allow. Row slots start at multiples of the odd ROW_SIZE (293),
// so that read is usually misaligned. It also skips the slice bounds check
// on the 4 bytes being read.
func deserializeRow(source []byte, destination *Row) {
	copy((*[ID_SIZE]byte)(unsafe.Pointer(&destination.id))[:], source[ID_OFFSET:ID_OFFSET+ID_SIZE])
	copy(destination.username[:], source[USERNAME_OFFSET:USERNAME_OFFSET+USERNAME_SIZE])
	copy(destination.email[:], source[EMAIL_OFFSET:EMAIL_OFFSET+EMAIL_SIZE])
}

// getPage returns the cached copy of page pageNum, loading it on first use.
//
// A page that starts before the end of the file is read from disk; the last
// page may be partial, in which case the rest of the buffer stays zero.
// Pages beyond the end of the file start out zeroed. The page stays cached
// in pager.pages for dbClose to flush. The program exits with status 1 if
// pageNum is outside the fixed page table or the file cannot be read.
func getPage(pager *Pager, pageNum uint32) []byte {
	// pages has TABLE_MAX_PAGES slots, so valid indexes are
	// 0..TABLE_MAX_PAGES-1. A ">" test would let TABLE_MAX_PAGES through and
	// panic on the index below instead of reporting the error.
	if pageNum >= TABLE_MAX_PAGES {
		fmt.Printf("Tried to fetch page number out of bounds. %d >= %d\n", pageNum, TABLE_MAX_PAGES)
		os.Exit(1)
	}

	if pager.pages[pageNum] == nil {
		page := make([]byte, PAGE_SIZE)
		numPages := pager.fileLength / PAGE_SIZE

		// A trailing partial page still holds rows.
		if pager.fileLength%PAGE_SIZE != 0 {
			numPages++
		}

		// Only pages that start before EOF have anything on disk.
		if pageNum < numPages {
			if _, err := pager.fileDescriptor.Seek(int64(pageNum)*PAGE_SIZE, io.SeekStart); err != nil {
				fmt.Printf("Error seeking: %v\n", err)
				os.Exit(1)
			}

			// ReadFull retries short reads. Hitting EOF early is expected for
			// the final, partially written page.
			_, err := io.ReadFull(pager.fileDescriptor, page)
			if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF {
				fmt.Printf("Error reading file: %v\n", err)
				os.Exit(1)
			}
		}

		pager.pages[pageNum] = page
	}

	return pager.pages[pageNum]
}

// rowSlot returns the ROW_SIZE bytes backing row rowNum, loading the page
// that contains it through the table's pager if necessary.
func rowSlot(table *Table, rowNum uint32) []byte {
	page := getPage(table.pager, rowNum/ROWS_PER_PAGE)
	start := (rowNum % ROWS_PER_PAGE) * ROW_SIZE
	return page[start : start+ROW_SIZE]
}

// pagerOpen opens filename read/write, creating it with mode 0644 if needed,
// and returns a Pager with no pages cached. fileLength records the file's
// current size. The program exits with status 1 if the file cannot be
// opened or measured, or if it is too large for the uint32 length field.
func pagerOpen(filename string) *Pager {
	fileDescriptor, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		fmt.Printf("Unable to open file: %v\n", err)
		os.Exit(1)
	}

	fileLength, err := fileDescriptor.Seek(0, io.SeekEnd)
	if err != nil {
		fmt.Printf("Error seeking: %v\n", err)
		os.Exit(1)
	}

	// fileLength is stored as uint32. Refuse files it cannot represent rather
	// than wrapping silently and miscounting rows.
	if fileLength > 1<<32-1 {
		fmt.Printf("Database file too large: %d bytes\n", fileLength)
		os.Exit(1)
	}

	// The zero-value pages array already has every slot nil (nothing cached).
	return &Pager{
		fileDescriptor: fileDescriptor,
		fileLength:     uint32(fileLength),
	}
}

// dbOpen opens (creating if necessary) the database file and returns a
// table whose row count is derived from the file's length.
//
// dbClose writes every full page as PAGE_SIZE bytes (ROWS_PER_PAGE rows plus
// unused padding at the end of the page). Only the final, partially filled
// page is cut off after its last row. Dividing the whole file length by
// ROW_SIZE would count that padding as extra rows once the file holds full
// pages (two full pages: 8192/293 = 27 instead of 26), so full pages and the
// trailing partial page are counted separately.
func dbOpen(filename string) *Table {
	pager := pagerOpen(filename)

	fullPages := pager.fileLength / PAGE_SIZE
	trailingBytes := pager.fileLength % PAGE_SIZE
	numRows := fullPages*ROWS_PER_PAGE + trailingBytes/ROW_SIZE

	return &Table{
		numRows: numRows,
		pager:   pager,
	}
}

// pagerFlush writes the first size bytes of cached page pageNum to its
// position in the file (pageNum*PAGE_SIZE). The program exits with status 1
// if the page is not cached or the seek or write fails.
func pagerFlush(pager *Pager, pageNum uint32, size uint32) {
	if pager.pages[pageNum] == nil {
		fmt.Printf("Tried to flush null page\n")
		os.Exit(1)
	}

	// Compute the offset in int64 so the multiplication cannot wrap in uint32.
	pageStart := int64(pageNum) * PAGE_SIZE
	offset, err := pager.fileDescriptor.Seek(pageStart, io.SeekStart)
	if err != nil {
		fmt.Printf("Error seeking: %v\n", err)
		os.Exit(1)
	}

	if offset != pageStart {
		fmt.Printf("Seek offset does not match page start\n")
		os.Exit(1)
	}

	// os.File.Write returns a non-nil error whenever it writes fewer bytes
	// than requested, so checking err covers short writes.
	if _, err := pager.fileDescriptor.Write(pager.pages[pageNum][:size]); err != nil {
		fmt.Printf("Error writing: %v\n", err)
		os.Exit(1)
	}
}

// dbClose writes cached pages back to the database file, closes the file,
// and terminates the process with exit status 0. It never returns to its
// caller.
//
// The first numRows/ROWS_PER_PAGE pages are full and are written as whole
// PAGE_SIZE pages. The trailing partially filled page, if any, is written
// only up to the end of its last row. Pages that were never loaded (nil) are
// skipped and keep their on-disk contents. Seek/write failures exit inside
// pagerFlush; a close failure exits here with status 1.
func dbClose(table *Table) {
	pager := table.pager
	numFullPages := table.numRows / ROWS_PER_PAGE

	for i := uint32(0); i < numFullPages; i++ {
		if pager.pages[i] == nil {
			continue
		}
		pagerFlush(pager, i, PAGE_SIZE)
		pager.pages[i] = nil
	}

	// Rows on the last, partially filled page are written without the
	// page's unused tail.
	numAdditionalRows := table.numRows % ROWS_PER_PAGE
	if numAdditionalRows > 0 {
		pageNum := numFullPages
		if pager.pages[pageNum] != nil {
			pagerFlush(pager, pageNum, numAdditionalRows*ROW_SIZE)
			pager.pages[pageNum] = nil
		}
	}

	err := pager.fileDescriptor.Close()
	if err != nil {
		fmt.Printf("Error closing db file: %v\n", err)
		os.Exit(1)
	}

	// Drop any remaining cached pages (mirrors the C version's free() loop;
	// the garbage collector does the actual reclamation).
	for i := 0; i < TABLE_MAX_PAGES; i++ {
		page := pager.pages[i]
		if page != nil {
			pager.pages[i] = nil
		}
	}

	os.Exit(0)
}

// printPrompt writes the REPL prompt (without a newline) to stdout.
func printPrompt() {
	fmt.Print("db > ")
}

// readInput reads the next line from reader into inputBuffer.
//
// The trailing "\n" and, if present, a preceding "\r" are removed so that
// input from CRLF terminals or files (".exit\r\n") is recognized. The reader
// is created once by the caller and reused: a fresh bufio.Reader per call
// could read ahead and drop buffered input. Any read error, including EOF
// before a newline, prints a message and exits with status 1.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	buffer, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}

	buffer = strings.TrimSuffix(buffer, "\n")
	buffer = strings.TrimSuffix(buffer, "\r")
	inputBuffer.inputLength = len(buffer)
	inputBuffer.buffer = buffer
}

// closeInputBuffer drops the last line held by inputBuffer; the garbage
// collector reclaims the memory.
func closeInputBuffer(inputBuffer *InputBuffer) {
	inputBuffer.buffer = ""
}

// doMetaCommand handles "."-prefixed commands. ".exit" releases the input
// buffer and hands off to dbClose, which flushes the table to disk and ends
// the process. Any other command is reported as unrecognized.
func doMetaCommand(inputBuffer *InputBuffer, table *Table) MetaCommandResult {
	if inputBuffer.buffer != ".exit" {
		return META_COMMAND_UNRECOGNIZED_COMMAND
	}
	closeInputBuffer(inputBuffer)
	dbClose(table)
	return META_COMMAND_SUCCESS
}

// prepareInsert parses "insert <id> <username> <email>" into statement.
//
// Results:
//   - PREPARE_SYNTAX_ERROR if there are not exactly four fields, if the id is
//     not a base-10 integer (e.g. "insert foo bar 1"), or if the id does not
//     fit in the uint32 id column;
//   - PREPARE_NEGATIVE_ID if the id is below zero;
//   - PREPARE_STRING_TOO_LONG if username or email exceeds its column size
//     in bytes;
//   - PREPARE_SUCCESS otherwise, with statement.rowToInsert filled in.
func prepareInsert(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	statement.typ = STATEMENT_INSERT

	tokens := strings.Fields(inputBuffer.buffer)
	if len(tokens) != 4 {
		return PREPARE_SYNTAX_ERROR
	}

	// Parse at an explicit 64-bit size so the whole uint32 range is accepted
	// even where int is 32 bits. Text that is not a number is a syntax error,
	// not a "negative" id.
	id, err := strconv.ParseInt(tokens[1], 10, 64)
	if err != nil {
		return PREPARE_SYNTAX_ERROR
	}

	if id < 0 {
		return PREPARE_NEGATIVE_ID
	}

	if id > 1<<32-1 {
		// The uint32 conversion below would silently truncate it.
		return PREPARE_SYNTAX_ERROR
	}

	if len(tokens[2]) > COLUMN_USERNAME_SIZE || len(tokens[3]) > COLUMN_EMAIL_SIZE {
		return PREPARE_STRING_TOO_LONG
	}

	statement.rowToInsert.id = uint32(id)
	copy(statement.rowToInsert.username[:], tokens[2])
	copy(statement.rowToInsert.email[:], tokens[3])

	return PREPARE_SUCCESS
}

// prepareStatement dispatches on the first whitespace-separated word of the
// input: "insert" is parsed by prepareInsert and "select" needs no
// arguments. Anything else, including a blank line, is unrecognized.
func prepareStatement(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	tokens := strings.Fields(inputBuffer.buffer)

	if len(tokens) == 0 {
		return PREPARE_UNRECOGNIZED_STATEMENT
	}

	switch tokens[0] {
	case "insert":
		return prepareInsert(inputBuffer, statement)
	case "select":
		statement.typ = STATEMENT_SELECT
		return PREPARE_SUCCESS
	default:
		return PREPARE_UNRECOGNIZED_STATEMENT
	}
}

// executeInsert appends the statement's row after the last stored row. It
// reports EXECUTE_TABLE_FULL, leaving the table untouched, once
// TABLE_MAX_ROWS rows are stored.
func executeInsert(statement *Statement, table *Table) ExecuteResult {
	if table.numRows < TABLE_MAX_ROWS {
		serializeRow(&statement.rowToInsert, rowSlot(table, table.numRows))
		table.numRows++
		return EXECUTE_SUCCESS
	}
	return EXECUTE_TABLE_FULL
}

// executeSelect prints every stored row in insertion order. A select takes
// no arguments, so statement is unused.
func executeSelect(statement *Statement, table *Table) ExecuteResult {
	var row Row
	for n := uint32(0); n < table.numRows; n++ {
		deserializeRow(rowSlot(table, n), &row)
		printRow(&row)
	}
	return EXECUTE_SUCCESS
}

// executeStatement dispatches on the statement type; unknown types are
// successful no-ops.
func executeStatement(statement *Statement, table *Table) ExecuteResult {
	if statement.typ == STATEMENT_INSERT {
		return executeInsert(statement, table)
	}
	if statement.typ == STATEMENT_SELECT {
		return executeSelect(statement, table)
	}
	return EXECUTE_SUCCESS
}

// main opens the database file named on the command line and runs the REPL.
// Blank lines are ignored. Lines starting with '.' are meta commands
// (".exit" flushes to disk and exits). Everything else is parsed and
// executed as a statement.
func main() {
	if len(os.Args) < 2 {
		fmt.Println("Must supply a database filename.")
		os.Exit(1)
	}

	filename := os.Args[1]
	table := dbOpen(filename)

	inputBuffer := newInputBuffer()
	reader := bufio.NewReader(os.Stdin)
	for {
		printPrompt()
		readInput(reader, inputBuffer)

		// An empty line leaves an empty buffer. Indexing buffer[0] below would
		// panic and lose every row not yet flushed by dbClose, so just prompt
		// again.
		if len(inputBuffer.buffer) == 0 {
			continue
		}

		if inputBuffer.buffer[0] == '.' {
			switch doMetaCommand(inputBuffer, table) {
			case META_COMMAND_SUCCESS:
				continue
			case META_COMMAND_UNRECOGNIZED_COMMAND:
				fmt.Printf("Unrecognized command '%s'\n", inputBuffer.buffer)
				continue
			}
		}

		var statement Statement
		switch prepareStatement(inputBuffer, &statement) {
		case PREPARE_SUCCESS:
			// Parsed; fall out of the switch and execute.
		case PREPARE_NEGATIVE_ID:
			fmt.Println("ID must be positive.")
			continue
		case PREPARE_STRING_TOO_LONG:
			fmt.Println("String is too long.")
			continue
		case PREPARE_SYNTAX_ERROR:
			fmt.Println("Syntax error. Could not parse statement.")
			continue
		case PREPARE_UNRECOGNIZED_STATEMENT:
			fmt.Printf("Unrecognized keyword at start of '%s'.\n", inputBuffer.buffer)
			continue
		}

		switch executeStatement(&statement, table) {
		case EXECUTE_SUCCESS:
			fmt.Println("Executed.")
		case EXECUTE_TABLE_FULL:
			fmt.Println("Error: Table full.")
		}
	}
}

运行:

go run golang/5.go test.db       
db > select
Executed.
db > insert 1 ab cd
Executed.
db > select
(1, ab, cd)
Executed.
db > .exit
go run golang/5.go test.db
db > select
(1, ab, cd)
Executed.
db > .exit

第六部分 - 游标抽象

golang实现

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"
	"unsafe"
)

// Row layout and storage limits. username and email carry one extra byte
// each (mirroring the C version's NUL terminator), so a serialized row is
// 4 + 33 + 256 = 293 bytes. Rows never straddle pages: a 4096-byte page
// holds ROWS_PER_PAGE = 13 rows plus unused padding, and the table caps at
// TABLE_MAX_ROWS = 1300 rows.
const (
	COLUMN_USERNAME_SIZE = 32
	COLUMN_EMAIL_SIZE    = 255
	ID_SIZE              = 4
	USERNAME_SIZE        = COLUMN_USERNAME_SIZE + 1
	EMAIL_SIZE           = COLUMN_EMAIL_SIZE + 1
	ID_OFFSET            = 0
	USERNAME_OFFSET      = ID_OFFSET + ID_SIZE
	EMAIL_OFFSET         = USERNAME_OFFSET + USERNAME_SIZE
	ROW_SIZE             = ID_SIZE + USERNAME_SIZE + EMAIL_SIZE
	PAGE_SIZE            = 4096
	TABLE_MAX_PAGES      = 100
	ROWS_PER_PAGE        = PAGE_SIZE / ROW_SIZE
	TABLE_MAX_ROWS       = ROWS_PER_PAGE * TABLE_MAX_PAGES
)

// InputBuffer holds the most recent input line. bufferLength mirrors the C
// version's field.
type InputBuffer struct {
	buffer       string
	bufferLength int
	inputLength  int
}

// MetaCommandResult is the outcome of a "."-prefixed command.
type MetaCommandResult int

const (
	META_COMMAND_SUCCESS MetaCommandResult = iota
	META_COMMAND_UNRECOGNIZED_COMMAND
)

// PrepareResult is the outcome of parsing an input line into a Statement.
type PrepareResult int

const (
	PREPARE_SUCCESS PrepareResult = iota
	PREPARE_NEGATIVE_ID
	PREPARE_STRING_TOO_LONG
	PREPARE_SYNTAX_ERROR
	PREPARE_UNRECOGNIZED_STATEMENT
)

// StatementType identifies the parsed statement; STATEMENT_INSERT is the
// zero value.
type StatementType int

const (
	STATEMENT_INSERT StatementType = iota
	STATEMENT_SELECT
)

// Row is one record of the single hard-coded table; the text columns are
// NUL-padded.
type Row struct {
	id       uint32
	username [COLUMN_USERNAME_SIZE + 1]byte
	email    [COLUMN_EMAIL_SIZE + 1]byte
}

// Statement is a parsed command; rowToInsert is only meaningful for inserts.
type Statement struct {
	typ         StatementType
	rowToInsert Row
}

// Pager caches pages of the database file. fileLength is the file size in
// bytes; a nil entry in pages means that page has not been loaded yet.
type Pager struct {
	fileDescriptor *os.File
	fileLength     uint32
	pages          [TABLE_MAX_PAGES][]byte
}

// Table is the single table: numRows rows stored through pager.
type Table struct {
	numRows uint32
	pager   *Pager
}

// Cursor is a position within table, identified by row number.
type Cursor struct {
	table      *Table
	rowNum     uint32
	endOfTable bool // marks the position one past the last row
}

// ExecuteResult is the outcome of running a Statement.
type ExecuteResult int

const (
	EXECUTE_SUCCESS ExecuteResult = iota
	EXECUTE_TABLE_FULL
)

// newInputBuffer allocates an InputBuffer whose fields all start at their
// zero values: an empty line and zero lengths.
func newInputBuffer() *InputBuffer {
	return new(InputBuffer)
}

// printRow writes row to stdout as "(id, username, email)". Trailing NUL
// padding of the fixed-size text columns is trimmed before printing.
func printRow(row *Row) {
	username := strings.TrimRight(string(row.username[:]), "\x00")
	email := strings.TrimRight(string(row.email[:]), "\x00")
	fmt.Printf("(%d, %s, %s)\n", row.id, username, email)
}

// serializeRow packs source into destination at the fixed row offsets.
// The id is copied as its raw in-memory bytes, i.e. in host byte order.
// destination is expected to hold at least ROW_SIZE bytes.
func serializeRow(source *Row, destination []byte) {
	idBytes := (*[ID_SIZE]byte)(unsafe.Pointer(&source.id))
	idField := destination[ID_OFFSET:]
	usernameField := destination[USERNAME_OFFSET:]
	emailField := destination[EMAIL_OFFSET:]

	copy(idField, idBytes[:])
	copy(usernameField, source.username[:])
	copy(emailField, source.email[:])
}

// deserializeRow unpacks a serialized row from source into destination.
// The id is read as a host-byte-order uint32 straight from the buffer.
// NOTE(review): that read is generally unaligned, which is fine on amd64/arm64.
func deserializeRow(source []byte, destination *Row) {
	destination.id = *(*uint32)(unsafe.Pointer(&source[ID_OFFSET]))

	usernameField := source[USERNAME_OFFSET : USERNAME_OFFSET+USERNAME_SIZE]
	emailField := source[EMAIL_OFFSET : EMAIL_OFFSET+EMAIL_SIZE]
	copy(destination.username[:], usernameField)
	copy(destination.email[:], emailField)
}

// getPage returns the cached copy of page pageNum, reading it from the
// database file on first access.
//
// A page past the end of the file, or the unwritten tail of a partially
// saved last page, comes back zero-filled. The process exits if pageNum is
// outside the page table or the read fails.
func getPage(pager *Pager, pageNum uint32) []byte {
	// pages has exactly TABLE_MAX_PAGES slots, so TABLE_MAX_PAGES itself is
	// already out of range.
	if pageNum >= TABLE_MAX_PAGES {
		fmt.Printf("Tried to fetch page number out of bounds. %d >= %d\n", pageNum, TABLE_MAX_PAGES)
		os.Exit(1)
	}

	if pager.pages[pageNum] == nil {
		// Cache miss: allocate a zeroed page and fill it from the file if it exists there.
		page := make([]byte, PAGE_SIZE)
		numPages := pager.fileLength / PAGE_SIZE
		// The last page may have been saved only partially.
		if pager.fileLength%PAGE_SIZE != 0 {
			numPages++
		}

		if pageNum < numPages {
			// ReadAt keeps reading until the page is full or the file ends.
			// A short final page yields io.EOF and its remainder stays zeroed.
			_, err := pager.fileDescriptor.ReadAt(page, int64(pageNum)*PAGE_SIZE)
			if err != nil && err != io.EOF {
				fmt.Printf("Error reading file: %v\n", err)
				os.Exit(1)
			}
		}

		pager.pages[pageNum] = page
	}

	return pager.pages[pageNum]
}

// tableStart returns a cursor on row 0. If the table has no rows, the cursor
// is already at the end of the table.
func tableStart(table *Table) *Cursor {
	return &Cursor{table: table, endOfTable: table.numRows == 0}
}

// tableEnd returns a cursor one past the last row, where executeInsert
// writes the next row.
func tableEnd(table *Table) *Cursor {
	return &Cursor{table: table, rowNum: table.numRows, endOfTable: true}
}

// cursorValue returns the ROW_SIZE-byte slot of the row under cursor,
// loading that row's page if needed.
func cursorValue(cursor *Cursor) []byte {
	page := getPage(cursor.table.pager, cursor.rowNum/ROWS_PER_PAGE)
	start := (cursor.rowNum % ROWS_PER_PAGE) * ROW_SIZE
	end := start + ROW_SIZE
	return page[start:end]
}

// cursorAdvance moves cursor to the next row. It sets endOfTable once the
// cursor has passed the last row; it never clears the flag.
func cursorAdvance(cursor *Cursor) {
	cursor.rowNum++
	cursor.endOfTable = cursor.endOfTable || cursor.rowNum >= cursor.table.numRows
}

// pagerOpen opens filename for reading and writing, creating it (mode 0644)
// if needed. It returns a Pager with an empty page cache and the current
// file length. The process exits if the file cannot be opened or measured.
func pagerOpen(filename string) *Pager {
	fileDescriptor, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		fmt.Printf("Unable to open file: %v\n", err)
		os.Exit(1)
	}

	// io.SeekEnd replaces the deprecated os.SEEK_END.
	fileLength, err := fileDescriptor.Seek(0, io.SeekEnd)
	if err != nil {
		fmt.Printf("Error seeking: %v\n", err)
		os.Exit(1)
	}

	// The zero-valued pages array already marks every page as not cached.
	return &Pager{
		fileDescriptor: fileDescriptor,
		fileLength:     uint32(fileLength),
	}
}

// dbOpen opens the database file and rebuilds the row count from its length.
//
// dbClose writes every full page as PAGE_SIZE bytes, which includes the unused
// bytes after the page's ROWS_PER_PAGE rows. It writes the final partial page
// only up to its last row. The row count is therefore
// fullPages*ROWS_PER_PAGE + tailBytes/ROW_SIZE. Computing fileLength/ROW_SIZE
// instead counts that padding as rows once two or more full pages exist.
func dbOpen(filename string) *Table {
	pager := pagerOpen(filename)

	fullPages := pager.fileLength / PAGE_SIZE
	tailBytes := pager.fileLength % PAGE_SIZE
	numRows := fullPages*ROWS_PER_PAGE + tailBytes/ROW_SIZE

	return &Table{
		numRows: numRows,
		pager:   pager,
	}
}

// pagerFlush writes the first size bytes of cached page pageNum to that page's
// position in the database file. The process exits if the page is not cached
// or the seek or write fails.
func pagerFlush(pager *Pager, pageNum uint32, size uint32) {
	if pager.pages[pageNum] == nil {
		fmt.Printf("Tried to flush null page\n")
		os.Exit(1)
	}

	// Widen before multiplying so the byte offset is computed in 64 bits.
	// io.SeekStart replaces the deprecated os.SEEK_SET.
	pageStart := int64(pageNum) * PAGE_SIZE
	offset, err := pager.fileDescriptor.Seek(pageStart, io.SeekStart)
	if err != nil {
		fmt.Printf("Error seeking: %v\n", err)
		os.Exit(1)
	}

	if offset != pageStart {
		fmt.Printf("Seek offset does not match page start\n")
		os.Exit(1)
	}

	_, err = pager.fileDescriptor.Write(pager.pages[pageNum][:size])
	if err != nil {
		fmt.Printf("Error writing: %v\n", err)
		os.Exit(1)
	}
}

// dbClose writes every cached page back to disk, closes the file and
// terminates the process with status 0. It does not return.
//
// Full pages are written whole. The trailing partial page is written only up
// to its last row.
func dbClose(table *Table) {
	pager := table.pager
	fullPages := table.numRows / ROWS_PER_PAGE
	extraRows := table.numRows % ROWS_PER_PAGE

	for pageNum := uint32(0); pageNum < fullPages; pageNum++ {
		if pager.pages[pageNum] != nil {
			pagerFlush(pager, pageNum, PAGE_SIZE)
			pager.pages[pageNum] = nil
		}
	}

	// Checking extraRows first matters: when the table is exactly full,
	// fullPages equals the page-table size and must not be used as an index.
	if extraRows > 0 && pager.pages[fullPages] != nil {
		pagerFlush(pager, fullPages, extraRows*ROW_SIZE)
		pager.pages[fullPages] = nil
	}

	if err := pager.fileDescriptor.Close(); err != nil {
		fmt.Printf("Error closing db file: %v\n", err)
		os.Exit(1)
	}

	// Drop any pages still referenced by the cache.
	for i := range pager.pages {
		pager.pages[i] = nil
	}

	os.Exit(0)
}

// printPrompt writes the REPL prompt to stdout without a trailing newline.
func printPrompt() {
	const prompt = "db > "
	fmt.Print(prompt)
}

// readInput reads one line from reader into inputBuffer.
//
// It strips the trailing "\n" and any "\r" before it, so CRLF input such as
// ".exit\r\n" is still recognised. Any read error, including end of input
// before a newline, prints a message and exits with status 1.
//
// The reader is created once in main and passed in. The ChatGPT draft built a
// new bufio.Reader on every call, which can drop input the previous reader
// had already buffered.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	buffer, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}

	buffer = strings.TrimSuffix(buffer, "\n")
	buffer = strings.TrimSuffix(buffer, "\r")
	inputBuffer.inputLength = len(buffer)
	inputBuffer.buffer = buffer
}

// closeInputBuffer clears the stored line. Only buffer is reset; the struct
// itself is reclaimed by the garbage collector.
func closeInputBuffer(inputBuffer *InputBuffer) {
	inputBuffer.buffer = inputBuffer.buffer[:0]
}

// doMetaCommand handles "."-prefixed REPL commands. Only ".exit" is known.
// It flushes and closes the database through dbClose, which exits the
// process, so META_COMMAND_SUCCESS is never actually observed for it.
func doMetaCommand(inputBuffer *InputBuffer, table *Table) MetaCommandResult {
	if inputBuffer.buffer != ".exit" {
		return META_COMMAND_UNRECOGNIZED_COMMAND
	}
	closeInputBuffer(inputBuffer)
	dbClose(table)
	return META_COMMAND_SUCCESS
}

// prepareInsert parses "insert <id> <username> <email>" into statement.
//
// It returns:
//   - PREPARE_SYNTAX_ERROR unless there are exactly four whitespace-separated
//     fields and the id is a base-10 integer that fits in a uint32;
//   - PREPARE_NEGATIVE_ID for a negative id;
//   - PREPARE_STRING_TOO_LONG when the username or email is longer than its
//     column allows.
func prepareInsert(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	statement.typ = STATEMENT_INSERT

	tokens := strings.Fields(inputBuffer.buffer)
	if len(tokens) != 4 {
		return PREPARE_SYNTAX_ERROR
	}

	// Parse in 64 bits so both negative ids and ids beyond uint32 are caught
	// before the narrowing conversion below. On a range error ParseInt returns
	// the clamped value, so an enormous negative number still reports as
	// negative. On a syntax error it returns 0.
	id, err := strconv.ParseInt(tokens[1], 10, 64)
	if id < 0 {
		return PREPARE_NEGATIVE_ID
	}
	if err != nil || id > 1<<32-1 {
		return PREPARE_SYNTAX_ERROR
	}

	if len(tokens[2]) > COLUMN_USERNAME_SIZE || len(tokens[3]) > COLUMN_EMAIL_SIZE {
		return PREPARE_STRING_TOO_LONG
	}

	statement.rowToInsert.id = uint32(id)
	copy(statement.rowToInsert.username[:], tokens[2])
	copy(statement.rowToInsert.email[:], tokens[3])

	return PREPARE_SUCCESS
}

// prepareStatement dispatches on the first word of the input line.
// "insert" is parsed by prepareInsert. "select" needs no arguments (extra
// words are ignored). Anything else, including an empty line, is
// PREPARE_UNRECOGNIZED_STATEMENT.
func prepareStatement(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	fields := strings.Fields(inputBuffer.buffer)
	if len(fields) > 0 {
		if fields[0] == "insert" {
			return prepareInsert(inputBuffer, statement)
		}
		if fields[0] == "select" {
			statement.typ = STATEMENT_SELECT
			return PREPARE_SUCCESS
		}
	}
	return PREPARE_UNRECOGNIZED_STATEMENT
}

// executeInsert appends statement.rowToInsert after the last row.
// It returns EXECUTE_TABLE_FULL once TABLE_MAX_ROWS rows are stored.
func executeInsert(statement *Statement, table *Table) ExecuteResult {
	if table.numRows >= TABLE_MAX_ROWS {
		return EXECUTE_TABLE_FULL
	}

	slot := cursorValue(tableEnd(table))
	serializeRow(&statement.rowToInsert, slot)
	table.numRows++

	return EXECUTE_SUCCESS
}

// executeSelect prints every row of the table in storage order.
// statement is unused because select takes no arguments.
func executeSelect(statement *Statement, table *Table) ExecuteResult {
	var row Row
	for cursor := tableStart(table); !cursor.endOfTable; cursorAdvance(cursor) {
		deserializeRow(cursorValue(cursor), &row)
		printRow(&row)
	}
	return EXECUTE_SUCCESS
}

// executeStatement runs a prepared statement. Unknown statement types are
// treated as a successful no-op.
func executeStatement(statement *Statement, table *Table) ExecuteResult {
	if statement.typ == STATEMENT_INSERT {
		return executeInsert(statement, table)
	}
	if statement.typ == STATEMENT_SELECT {
		return executeSelect(statement, table)
	}
	return EXECUTE_SUCCESS
}

// main runs the REPL against the database file named by the first argument.
// Lines starting with "." are meta commands; everything else is parsed and
// executed as a statement.
func main() {
	if len(os.Args) < 2 {
		fmt.Println("Must supply a database filename.")
		os.Exit(1)
	}

	filename := os.Args[1]
	table := dbOpen(filename)

	inputBuffer := newInputBuffer()
	reader := bufio.NewReader(os.Stdin)
	for {
		printPrompt()
		readInput(reader, inputBuffer)

		// HasPrefix is safe on an empty line, where indexing buffer[0] would panic.
		// An empty line falls through and is reported as an unrecognized statement.
		if strings.HasPrefix(inputBuffer.buffer, ".") {
			switch doMetaCommand(inputBuffer, table) {
			case META_COMMAND_SUCCESS:
				continue
			case META_COMMAND_UNRECOGNIZED_COMMAND:
				fmt.Printf("Unrecognized command '%s'\n", inputBuffer.buffer)
				continue
			}
		}

		var statement Statement
		switch prepareStatement(inputBuffer, &statement) {
		case PREPARE_SUCCESS:
			// Parsed; go on to execute it.
		case PREPARE_NEGATIVE_ID:
			fmt.Println("ID must be positive.")
			continue
		case PREPARE_STRING_TOO_LONG:
			fmt.Println("String is too long.")
			continue
		case PREPARE_SYNTAX_ERROR:
			fmt.Println("Syntax error. Could not parse statement.")
			continue
		case PREPARE_UNRECOGNIZED_STATEMENT:
			fmt.Printf("Unrecognized keyword at start of '%s'.\n", inputBuffer.buffer)
			continue
		}

		switch executeStatement(&statement, table) {
		case EXECUTE_SUCCESS:
			fmt.Println("Executed.")
		case EXECUTE_TABLE_FULL:
			fmt.Println("Error: Table full.")
		}
	}
}

测试沿用原教程对应章节的测试用例(原 Ruby 测试已改写为 Python)。

第八部分 - B-Tree叶节点格式

python 测试用例:

import sys,os
from util import run_script

# 测试btree结构
def test_one_node_btree_structure(db_file=""):
    """Insert three rows out of key order and check the `.btree` dump.

    The expected dump is a single leaf whose keys appear in insertion order
    (3, 1, 2): this version stores keys without sorting them.
    """
    script = [
        "insert 3 user3 person3@example.com",
        "insert 1 user1 person1@example.com",
        "insert 2 user2 person2@example.com",
        ".btree",
        ".exit"
    ]

    result = run_script(script, db_file=db_file)

    expected_result = [
        "db > Executed.",
        "db > Executed.",
        "db > Executed.",
        "db > Tree:",
        "leaf (size 3)",
        "  - 0 : 3",
        "  - 1 : 1",
        "  - 2 : 2",
        "db > "
    ]

    # Show the actual REPL transcript only when the comparison fails,
    # instead of printing it on every run.
    assert result == expected_result, f"unexpected output: {result}"
    print(f"{sys._getframe().f_code.co_name} passed")


# 测试打印常量
def test_print_constants(db_file=""):
    """Check that `.constants` reports the expected row and leaf-node layout sizes."""
    script = [
        ".constants",
        ".exit",
    ]
    result = run_script(script, db_file=db_file)

    expected_result = [
        "db > Constants:",
        "ROW_SIZE: 293",
        "COMMON_NODE_HEADER_SIZE: 6",
        "LEAF_NODE_HEADER_SIZE: 10",
        "LEAF_NODE_CELL_SIZE: 297",
        "LEAF_NODE_SPACE_FOR_CELLS: 4086",
        "LEAF_NODE_MAX_CELLS: 13",
        "db > ",
    ]

    # Include the actual output in the failure message to ease debugging.
    assert result == expected_result, f"unexpected output: {result}"
    print(f"{sys._getframe().f_code.co_name} passed")

# Entry point: the db file path is the first CLI argument. Each run starts
# from a fresh file.
if len(sys.argv) < 2:
    # Exit non-zero so a missing argument is not mistaken for a passing run.
    print("need db file path", file=sys.stderr)
    sys.exit(1)
db_file = sys.argv[1]
if os.path.exists(db_file):
    os.remove(db_file)

test_one_node_btree_structure(db_file)
test_print_constants(db_file)

print("all tests passed.")

golang实现(键按插入顺序存储,尚未排序)

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"
	"unsafe"
)

// Row layout constants (all sizes in bytes). A serialized row is
// id | username | email at the *_OFFSET positions, ROW_SIZE = 4 + 33 + 256 = 293.
// Each text column reserves one byte beyond the limit prepareInsert enforces.
const (
	COLUMN_USERNAME_SIZE = 32
	COLUMN_EMAIL_SIZE    = 255
	ID_SIZE              = 4
	USERNAME_SIZE        = COLUMN_USERNAME_SIZE + 1
	EMAIL_SIZE           = COLUMN_EMAIL_SIZE + 1
	ID_OFFSET            = 0
	USERNAME_OFFSET      = ID_OFFSET + ID_SIZE
	EMAIL_OFFSET         = USERNAME_OFFSET + USERNAME_SIZE
	ROW_SIZE             = ID_SIZE + USERNAME_SIZE + EMAIL_SIZE
	// Pager geometry: the cache holds at most TABLE_MAX_PAGES pages.
	PAGE_SIZE            = 4096
	TABLE_MAX_PAGES      = 100
)

// NodeType is the value stored in a node's type byte (NODE_INTERNAL = 0,
// NODE_LEAF = 1). The code in this part never writes that byte.
type NodeType uint8

const (
	NODE_INTERNAL NodeType = iota
	NODE_LEAF
)

// Common Node Header Layout
// Byte 0 holds the node type, byte 1 the is-root flag and bytes 2-5 the
// parent pointer: 6 bytes in total. The is-root and parent-pointer fields are
// reserved by the layout but not read or written by the code in this part.
const (
	NODE_TYPE_SIZE          = 1
	NODE_TYPE_OFFSET        = 0
	IS_ROOT_SIZE            = 1
	IS_ROOT_OFFSET          = NODE_TYPE_OFFSET + NODE_TYPE_SIZE
	PARENT_POINTER_SIZE     = 4
	PARENT_POINTER_OFFSET   = IS_ROOT_OFFSET + IS_ROOT_SIZE
	COMMON_NODE_HEADER_SIZE = NODE_TYPE_SIZE + IS_ROOT_SIZE + PARENT_POINTER_SIZE
)

// Leaf Node Header Layout
// The common header is followed by a 4-byte cell count: 10 bytes in total.
const (
	LEAF_NODE_NUM_CELLS_SIZE   = 4
	LEAF_NODE_NUM_CELLS_OFFSET = COMMON_NODE_HEADER_SIZE
	LEAF_NODE_HEADER_SIZE      = COMMON_NODE_HEADER_SIZE + LEAF_NODE_NUM_CELLS_SIZE
)

// Leaf Node Body Layout
// Each cell is a 4-byte key followed by a 293-byte serialized row (297 bytes).
// The 4086 bytes after the header hold at most 13 cells.
const (
	LEAF_NODE_KEY_SIZE        = 4
	LEAF_NODE_KEY_OFFSET      = 0
	LEAF_NODE_VALUE_SIZE      = ROW_SIZE
	LEAF_NODE_VALUE_OFFSET    = LEAF_NODE_KEY_OFFSET + LEAF_NODE_KEY_SIZE
	LEAF_NODE_CELL_SIZE       = LEAF_NODE_KEY_SIZE + LEAF_NODE_VALUE_SIZE
	LEAF_NODE_SPACE_FOR_CELLS = PAGE_SIZE - LEAF_NODE_HEADER_SIZE
	LEAF_NODE_MAX_CELLS       = LEAF_NODE_SPACE_FOR_CELLS / LEAF_NODE_CELL_SIZE
)

// InputBuffer holds the most recent line read by readInput.
// inputLength is set to len(buffer) by readInput. bufferLength is only ever
// initialised to 0 here — NOTE(review): looks like a leftover from the C version.
type InputBuffer struct {
	buffer       string
	bufferLength int
	inputLength  int
}

// MetaCommandResult is the outcome of doMetaCommand for a "."-prefixed line.
type MetaCommandResult int

const (
	META_COMMAND_SUCCESS MetaCommandResult = iota
	META_COMMAND_UNRECOGNIZED_COMMAND
)

// PrepareResult is the outcome of parsing a statement in prepareStatement.
type PrepareResult int

const (
	PREPARE_SUCCESS PrepareResult = iota
	PREPARE_NEGATIVE_ID
	PREPARE_STRING_TOO_LONG
	PREPARE_SYNTAX_ERROR
	PREPARE_UNRECOGNIZED_STATEMENT
)

// StatementType identifies which SQL-like statement was parsed.
type StatementType int

const (
	STATEMENT_INSERT StatementType = iota
	STATEMENT_SELECT
)

// Row is the in-memory form of a table row. The text columns are fixed-size,
// zero-padded byte arrays.
type Row struct {
	id       uint32
	username [COLUMN_USERNAME_SIZE + 1]byte
	email    [COLUMN_EMAIL_SIZE + 1]byte
}

// Statement is a parsed command; rowToInsert is only meaningful for inserts.
type Statement struct {
	typ         StatementType
	rowToInsert Row
}

// Pager owns the database file and a cache of its pages.
// fileLength is the size measured when the file was opened. numPages starts
// at the file's page count and grows as getPage hands out higher pages.
// A nil entry in pages means that page is not loaded.
type Pager struct {
	fileDescriptor *os.File
	fileLength     uint32
	numPages       uint32
	pages          [TABLE_MAX_PAGES][]byte
}

// Table is a B-tree whose root node lives on page rootPageNum.
type Table struct {
	rootPageNum uint32
	pager       *Pager
}

// Cursor points at cell cellNum of the node stored on page pageNum.
type Cursor struct {
	table      *Table
	pageNum    uint32
	cellNum    uint32
	endOfTable bool // marks the position one past the last element
}

// ExecuteResult is the outcome of executing a prepared statement.
type ExecuteResult int

const (
	EXECUTE_SUCCESS ExecuteResult = iota
	EXECUTE_TABLE_FULL
)

// newInputBuffer allocates an InputBuffer whose fields all start at their
// zero values: an empty line and zero lengths.
func newInputBuffer() *InputBuffer {
	return new(InputBuffer)
}

// leafNodeNumCells returns a pointer to the uint32 cell count in the leaf
// header, which is stored in host byte order. The count sits at byte offset 6,
// which is not a multiple of 4. NOTE(review): this relies on the platform
// tolerating unaligned loads (true on amd64/arm64).
func leafNodeNumCells(node []byte) *uint32 {
	countPtr := unsafe.Pointer(&node[LEAF_NODE_NUM_CELLS_OFFSET])
	return (*uint32)(countPtr)
}

// leafNodeCell returns the LEAF_NODE_CELL_SIZE-byte slice (key followed by
// value) of cell cellNum in a leaf node.
func leafNodeCell(node []byte, cellNum uint32) []byte {
	cells := node[LEAF_NODE_HEADER_SIZE:]
	start := cellNum * LEAF_NODE_CELL_SIZE
	return cells[start : start+LEAF_NODE_CELL_SIZE]
}

// leafNodeKey returns a pointer to the uint32 key at the start of cell
// cellNum. The key is stored in host byte order.
func leafNodeKey(node []byte, cellNum uint32) *uint32 {
	cellStart := LEAF_NODE_HEADER_SIZE + cellNum*LEAF_NODE_CELL_SIZE
	keyPtr := unsafe.Pointer(&node[cellStart])
	return (*uint32)(keyPtr)
}

// leafNodeValue returns the serialized-row slice that follows the key in
// cell cellNum.
func leafNodeValue(node []byte, cellNum uint32) []byte {
	start := LEAF_NODE_HEADER_SIZE + cellNum*LEAF_NODE_CELL_SIZE + LEAF_NODE_KEY_SIZE
	end := start + LEAF_NODE_VALUE_SIZE
	return node[start:end]
}

// printConstants prints the row and leaf-node layout sizes, one "NAME: value"
// line each, for the ".constants" meta command.
func printConstants() {
	constants := []struct {
		name  string
		value int
	}{
		{"ROW_SIZE", ROW_SIZE},
		{"COMMON_NODE_HEADER_SIZE", COMMON_NODE_HEADER_SIZE},
		{"LEAF_NODE_HEADER_SIZE", LEAF_NODE_HEADER_SIZE},
		{"LEAF_NODE_CELL_SIZE", LEAF_NODE_CELL_SIZE},
		{"LEAF_NODE_SPACE_FOR_CELLS", LEAF_NODE_SPACE_FOR_CELLS},
		{"LEAF_NODE_MAX_CELLS", LEAF_NODE_MAX_CELLS},
	}
	for _, c := range constants {
		fmt.Printf("%s: %d\n", c.name, c.value)
	}
}

// printLeafNode dumps a leaf's cell count, then one "  - index : key" line per cell.
func printLeafNode(node []byte) {
	count := *leafNodeNumCells(node)
	fmt.Printf("leaf (size %d)\n", count)
	for cell := uint32(0); cell < count; cell++ {
		fmt.Printf("  - %d : %d\n", cell, *leafNodeKey(node, cell))
	}
}

// printRow writes row to stdout as "(id, username, email)". Trailing NUL
// padding of the fixed-size text columns is trimmed before printing.
func printRow(row *Row) {
	username := strings.TrimRight(string(row.username[:]), "\x00")
	email := strings.TrimRight(string(row.email[:]), "\x00")
	fmt.Printf("(%d, %s, %s)\n", row.id, username, email)
}

// serializeRow packs source into destination at the fixed row offsets.
// The id is copied as its raw in-memory bytes, i.e. in host byte order.
// destination is expected to hold at least ROW_SIZE bytes.
func serializeRow(source *Row, destination []byte) {
	idBytes := (*[ID_SIZE]byte)(unsafe.Pointer(&source.id))
	idField := destination[ID_OFFSET:]
	usernameField := destination[USERNAME_OFFSET:]
	emailField := destination[EMAIL_OFFSET:]

	copy(idField, idBytes[:])
	copy(usernameField, source.username[:])
	copy(emailField, source.email[:])
}

// deserializeRow unpacks a serialized row from source into destination.
// The id is read as a host-byte-order uint32 straight from the buffer.
func deserializeRow(source []byte, destination *Row) {
	destination.id = *(*uint32)(unsafe.Pointer(&source[ID_OFFSET]))

	usernameField := source[USERNAME_OFFSET : USERNAME_OFFSET+USERNAME_SIZE]
	emailField := source[EMAIL_OFFSET : EMAIL_OFFSET+EMAIL_SIZE]
	copy(destination.username[:], usernameField)
	copy(destination.email[:], emailField)
}

// initializeLeafNode turns node into an empty leaf by zeroing its cell count.
// In this part the node-type byte is left untouched.
func initializeLeafNode(node []byte) {
	numCells := leafNodeNumCells(node)
	*numCells = 0
}

// getPage returns the cached copy of page pageNum, reading it from the
// database file on first access.
//
// A page past the end of the file comes back zero-filled, and pager.numPages
// grows to cover any page handed out. The process exits if pageNum is outside
// the page table or the read fails.
func getPage(pager *Pager, pageNum uint32) []byte {
	// pages has exactly TABLE_MAX_PAGES slots, so TABLE_MAX_PAGES itself is
	// already out of range.
	if pageNum >= TABLE_MAX_PAGES {
		fmt.Printf("Tried to fetch page number out of bounds. %d >= %d\n", pageNum, TABLE_MAX_PAGES)
		os.Exit(1)
	}

	if pager.pages[pageNum] == nil {
		// Cache miss: allocate a zeroed page and fill it from the file if it exists there.
		page := make([]byte, PAGE_SIZE)
		numPages := pager.fileLength / PAGE_SIZE
		// The last page may have been saved only partially.
		if pager.fileLength%PAGE_SIZE != 0 {
			numPages++
		}

		if pageNum < numPages {
			// ReadAt keeps reading until the page is full or the file ends.
			// A short final page yields io.EOF and its remainder stays zeroed.
			_, err := pager.fileDescriptor.ReadAt(page, int64(pageNum)*PAGE_SIZE)
			if err != nil && err != io.EOF {
				fmt.Printf("Error reading file: %v\n", err)
				os.Exit(1)
			}
		}

		pager.pages[pageNum] = page
		if pageNum >= pager.numPages {
			pager.numPages = pageNum + 1
		}
	}

	return pager.pages[pageNum]
}

// tableStart returns a cursor on cell 0 of the root node. If the root has no
// cells, the cursor is already at the end of the table.
func tableStart(table *Table) *Cursor {
	root := getPage(table.pager, table.rootPageNum)
	return &Cursor{
		table:      table,
		pageNum:    table.rootPageNum,
		endOfTable: *leafNodeNumCells(root) == 0,
	}
}

// tableEnd returns a cursor one past the root node's last cell, which is
// where executeInsert places the next key in this part.
func tableEnd(table *Table) *Cursor {
	root := getPage(table.pager, table.rootPageNum)
	return &Cursor{
		table:      table,
		pageNum:    table.rootPageNum,
		cellNum:    *leafNodeNumCells(root),
		endOfTable: true,
	}
}

// cursorValue returns the serialized-row bytes of the cell under cursor.
func cursorValue(cursor *Cursor) []byte {
	return leafNodeValue(getPage(cursor.table.pager, cursor.pageNum), cursor.cellNum)
}

// cursorAdvance moves cursor to the next cell of its node. It sets
// endOfTable once the cursor passes the node's last cell.
func cursorAdvance(cursor *Cursor) {
	node := getPage(cursor.table.pager, cursor.pageNum)
	cursor.cellNum++
	if numCells := *leafNodeNumCells(node); cursor.cellNum >= numCells {
		cursor.endOfTable = true
	}
}

// pagerOpen opens filename for reading and writing, creating it (mode 0644)
// if needed, and returns a Pager with an empty page cache.
//
// pagerFlush always writes whole pages, so a length that is not a multiple of
// PAGE_SIZE is treated as corruption. The process exits on that, or if the
// file cannot be opened or measured.
func pagerOpen(filename string) *Pager {
	fileDescriptor, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		fmt.Printf("Unable to open file: %v\n", err)
		os.Exit(1)
	}

	// io.SeekEnd replaces the deprecated os.SEEK_END.
	fileLength, err := fileDescriptor.Seek(0, io.SeekEnd)
	if err != nil {
		fmt.Printf("Error seeking: %v\n", err)
		os.Exit(1)
	}

	if fileLength%PAGE_SIZE != 0 {
		fmt.Printf("Db file is not a whole number of pages. Corrupt file.\n")
		os.Exit(1)
	}

	// The zero-valued pages array already marks every page as not cached.
	return &Pager{
		fileDescriptor: fileDescriptor,
		fileLength:     uint32(fileLength),
		numPages:       uint32(fileLength / PAGE_SIZE),
	}
}

// dbOpen opens the database file and returns a table rooted at page 0.
// A brand-new (empty) file gets page 0 initialised as an empty leaf.
func dbOpen(filename string) *Table {
	pager := pagerOpen(filename)
	if pager.numPages == 0 {
		initializeLeafNode(getPage(pager, 0))
	}
	return &Table{rootPageNum: 0, pager: pager}
}

// pagerFlush writes the full PAGE_SIZE bytes of cached page pageNum to that
// page's position in the database file. The process exits if the page is not
// cached or the seek or write fails.
func pagerFlush(pager *Pager, pageNum uint32) {
	if pager.pages[pageNum] == nil {
		fmt.Printf("Tried to flush null page\n")
		os.Exit(1)
	}

	// Widen before multiplying so the byte offset is computed in 64 bits.
	// io.SeekStart replaces the deprecated os.SEEK_SET.
	pageStart := int64(pageNum) * PAGE_SIZE
	offset, err := pager.fileDescriptor.Seek(pageStart, io.SeekStart)
	if err != nil {
		fmt.Printf("Error seeking: %v\n", err)
		os.Exit(1)
	}

	if offset != pageStart {
		fmt.Printf("Seek offset does not match page start\n")
		os.Exit(1)
	}

	_, err = pager.fileDescriptor.Write(pager.pages[pageNum][:PAGE_SIZE])
	if err != nil {
		fmt.Printf("Error writing: %v\n", err)
		os.Exit(1)
	}
}

// dbClose writes every cached page among the first pager.numPages back to
// disk, closes the file and exits the process with status 0. It does not return.
func dbClose(table *Table) {
	pager := table.pager

	for pageNum := uint32(0); pageNum < pager.numPages; pageNum++ {
		if pager.pages[pageNum] == nil {
			continue
		}
		pagerFlush(pager, pageNum)
		pager.pages[pageNum] = nil
	}

	if err := pager.fileDescriptor.Close(); err != nil {
		fmt.Printf("Error closing db file: %v\n", err)
		os.Exit(1)
	}

	// Drop any pages still referenced by the cache.
	for i := range pager.pages {
		pager.pages[i] = nil
	}

	os.Exit(0)
}

// printPrompt writes the REPL prompt to stdout without a trailing newline.
func printPrompt() {
	const prompt = "db > "
	fmt.Print(prompt)
}

// readInput reads one line from reader into inputBuffer.
//
// It strips the trailing "\n" and any "\r" before it, so CRLF input such as
// ".exit\r\n" is still recognised. Any read error, including end of input
// before a newline, prints a message and exits with status 1.
//
// The reader is created once in main and passed in. The ChatGPT draft built a
// new bufio.Reader on every call, which can drop input the previous reader
// had already buffered.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	buffer, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}

	buffer = strings.TrimSuffix(buffer, "\n")
	buffer = strings.TrimSuffix(buffer, "\r")
	inputBuffer.inputLength = len(buffer)
	inputBuffer.buffer = buffer
}

// closeInputBuffer clears the stored line. Only buffer is reset; the struct
// itself is reclaimed by the garbage collector.
func closeInputBuffer(inputBuffer *InputBuffer) {
	inputBuffer.buffer = inputBuffer.buffer[:0]
}

// doMetaCommand handles "."-prefixed REPL commands:
//   - ".exit" closes the database via dbClose, which exits the process;
//   - ".btree" prints the root leaf node;
//   - ".constants" prints the layout constants.
//
// Any other command yields META_COMMAND_UNRECOGNIZED_COMMAND.
func doMetaCommand(inputBuffer *InputBuffer, table *Table) MetaCommandResult {
	switch inputBuffer.buffer {
	case ".exit":
		closeInputBuffer(inputBuffer)
		dbClose(table)
		return META_COMMAND_SUCCESS
	case ".btree":
		fmt.Print("Tree:\n")
		// Print the table's actual root rather than assuming page 0.
		printLeafNode(getPage(table.pager, table.rootPageNum))
		return META_COMMAND_SUCCESS
	case ".constants":
		fmt.Print("Constants:\n")
		printConstants()
		return META_COMMAND_SUCCESS
	default:
		return META_COMMAND_UNRECOGNIZED_COMMAND
	}
}

// prepareInsert parses "insert <id> <username> <email>" into statement.
//
// It returns:
//   - PREPARE_SYNTAX_ERROR unless there are exactly four whitespace-separated
//     fields and the id is a base-10 integer that fits in a uint32;
//   - PREPARE_NEGATIVE_ID for a negative id;
//   - PREPARE_STRING_TOO_LONG when the username or email is longer than its
//     column allows.
func prepareInsert(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	statement.typ = STATEMENT_INSERT

	tokens := strings.Fields(inputBuffer.buffer)
	if len(tokens) != 4 {
		return PREPARE_SYNTAX_ERROR
	}

	// Parse in 64 bits so both negative ids and ids beyond uint32 are caught
	// before the narrowing conversion below. On a range error ParseInt returns
	// the clamped value, so an enormous negative number still reports as
	// negative. On a syntax error it returns 0.
	id, err := strconv.ParseInt(tokens[1], 10, 64)
	if id < 0 {
		return PREPARE_NEGATIVE_ID
	}
	if err != nil || id > 1<<32-1 {
		return PREPARE_SYNTAX_ERROR
	}

	if len(tokens[2]) > COLUMN_USERNAME_SIZE || len(tokens[3]) > COLUMN_EMAIL_SIZE {
		return PREPARE_STRING_TOO_LONG
	}

	statement.rowToInsert.id = uint32(id)
	copy(statement.rowToInsert.username[:], tokens[2])
	copy(statement.rowToInsert.email[:], tokens[3])

	return PREPARE_SUCCESS
}

// prepareStatement dispatches on the first word of the input line.
// "insert" is parsed by prepareInsert. "select" needs no arguments (extra
// words are ignored). Anything else, including an empty line, is
// PREPARE_UNRECOGNIZED_STATEMENT.
func prepareStatement(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	fields := strings.Fields(inputBuffer.buffer)
	if len(fields) > 0 {
		if fields[0] == "insert" {
			return prepareInsert(inputBuffer, statement)
		}
		if fields[0] == "select" {
			statement.typ = STATEMENT_SELECT
			return PREPARE_SUCCESS
		}
	}
	return PREPARE_UNRECOGNIZED_STATEMENT
}

// leafNodeInsert stores (key, value) in the leaf at cursor.pageNum at position
// cursor.cellNum, shifting any later cells one slot to the right. Splitting is
// not implemented yet, so a full leaf prints a message and exits the process.
func leafNodeInsert(cursor *Cursor, key uint32, value *Row) {
	node := getPage(cursor.table.pager, cursor.pageNum)
	numCells := *leafNodeNumCells(node)

	if numCells >= LEAF_NODE_MAX_CELLS {
		fmt.Println("Need to implement splitting a leaf node.")
		os.Exit(1)
	}

	// Walk from the tail so no cell is overwritten before it has been moved.
	// The loop does nothing when inserting at the end (cellNum == numCells).
	for dst := numCells; dst > cursor.cellNum; dst-- {
		copy(leafNodeCell(node, dst), leafNodeCell(node, dst-1))
	}

	*leafNodeNumCells(node)++
	*leafNodeKey(node, cursor.cellNum) = key
	serializeRow(value, leafNodeValue(node, cursor.cellNum))
}

// executeInsert appends statement.rowToInsert, keyed by its id, after the
// root leaf's last cell. It returns EXECUTE_TABLE_FULL once the root leaf
// holds LEAF_NODE_MAX_CELLS cells.
func executeInsert(statement *Statement, table *Table) ExecuteResult {
	root := getPage(table.pager, table.rootPageNum)
	if *leafNodeNumCells(root) >= LEAF_NODE_MAX_CELLS {
		return EXECUTE_TABLE_FULL
	}

	row := &statement.rowToInsert
	leafNodeInsert(tableEnd(table), row.id, row)
	return EXECUTE_SUCCESS
}

// executeSelect prints every row of the table in cell order.
// statement is unused because select takes no arguments.
func executeSelect(statement *Statement, table *Table) ExecuteResult {
	var row Row
	for cursor := tableStart(table); !cursor.endOfTable; cursorAdvance(cursor) {
		deserializeRow(cursorValue(cursor), &row)
		printRow(&row)
	}
	return EXECUTE_SUCCESS
}

// executeStatement runs a prepared statement. Unknown statement types are
// treated as a successful no-op.
func executeStatement(statement *Statement, table *Table) ExecuteResult {
	if statement.typ == STATEMENT_INSERT {
		return executeInsert(statement, table)
	}
	if statement.typ == STATEMENT_SELECT {
		return executeSelect(statement, table)
	}
	return EXECUTE_SUCCESS
}

// main runs the REPL against the database file named by the first argument.
// Lines starting with "." are meta commands; everything else is parsed and
// executed as a statement.
func main() {
	if len(os.Args) < 2 {
		fmt.Println("Must supply a database filename.")
		os.Exit(1)
	}

	filename := os.Args[1]
	table := dbOpen(filename)

	inputBuffer := newInputBuffer()
	reader := bufio.NewReader(os.Stdin)
	for {
		printPrompt()
		readInput(reader, inputBuffer)

		// HasPrefix is safe on an empty line, where indexing buffer[0] would panic.
		// An empty line falls through and is reported as an unrecognized statement.
		if strings.HasPrefix(inputBuffer.buffer, ".") {
			switch doMetaCommand(inputBuffer, table) {
			case META_COMMAND_SUCCESS:
				continue
			case META_COMMAND_UNRECOGNIZED_COMMAND:
				fmt.Printf("Unrecognized command '%s'\n", inputBuffer.buffer)
				continue
			}
		}

		var statement Statement
		switch prepareStatement(inputBuffer, &statement) {
		case PREPARE_SUCCESS:
			// Parsed; go on to execute it.
		case PREPARE_NEGATIVE_ID:
			fmt.Println("ID must be positive.")
			continue
		case PREPARE_STRING_TOO_LONG:
			fmt.Println("String is too long.")
			continue
		case PREPARE_SYNTAX_ERROR:
			fmt.Println("Syntax error. Could not parse statement.")
			continue
		case PREPARE_UNRECOGNIZED_STATEMENT:
			fmt.Printf("Unrecognized keyword at start of '%s'.\n", inputBuffer.buffer)
			continue
		}

		switch executeStatement(&statement, table) {
		case EXECUTE_SUCCESS:
			fmt.Println("Executed.")
		case EXECUTE_TABLE_FULL:
			fmt.Println("Error: Table full.")
		}
	}
}

第九部分 - 二分搜索和重复键

golang实现

package main

import (
	"bufio"
	"fmt"
	"io"
	"os"
	"strconv"
	"strings"
	"unsafe"
)

// Row layout constants (all sizes in bytes). A serialized row is
// id | username | email at the *_OFFSET positions, ROW_SIZE = 4 + 33 + 256 = 293.
// Each text column reserves one byte beyond the limit prepareInsert enforces.
const (
	COLUMN_USERNAME_SIZE = 32
	COLUMN_EMAIL_SIZE    = 255
	ID_SIZE              = 4
	USERNAME_SIZE        = COLUMN_USERNAME_SIZE + 1
	EMAIL_SIZE           = COLUMN_EMAIL_SIZE + 1
	ID_OFFSET            = 0
	USERNAME_OFFSET      = ID_OFFSET + ID_SIZE
	EMAIL_OFFSET         = USERNAME_OFFSET + USERNAME_SIZE
	ROW_SIZE             = ID_SIZE + USERNAME_SIZE + EMAIL_SIZE
	// Pager geometry: the cache holds at most TABLE_MAX_PAGES pages.
	PAGE_SIZE            = 4096
	TABLE_MAX_PAGES      = 100
)

// NodeType is the value stored in a node's type byte (NODE_INTERNAL = 0,
// NODE_LEAF = 1). A zero-filled page therefore reads as NODE_INTERNAL until
// initializeLeafNode marks it as a leaf.
type NodeType uint8

const (
	NODE_INTERNAL NodeType = iota
	NODE_LEAF
)

// Common Node Header Layout
// Byte 0 holds the node type, byte 1 the is-root flag and bytes 2-5 the
// parent pointer: 6 bytes in total.
const (
	NODE_TYPE_SIZE          = 1
	NODE_TYPE_OFFSET        = 0
	IS_ROOT_SIZE            = 1
	IS_ROOT_OFFSET          = NODE_TYPE_OFFSET + NODE_TYPE_SIZE
	PARENT_POINTER_SIZE     = 4
	PARENT_POINTER_OFFSET   = IS_ROOT_OFFSET + IS_ROOT_SIZE
	COMMON_NODE_HEADER_SIZE = NODE_TYPE_SIZE + IS_ROOT_SIZE + PARENT_POINTER_SIZE
)

// Leaf Node Header Layout
// The common header is followed by a 4-byte cell count: 10 bytes in total.
const (
	LEAF_NODE_NUM_CELLS_SIZE   = 4
	LEAF_NODE_NUM_CELLS_OFFSET = COMMON_NODE_HEADER_SIZE
	LEAF_NODE_HEADER_SIZE      = COMMON_NODE_HEADER_SIZE + LEAF_NODE_NUM_CELLS_SIZE
)

// Leaf Node Body Layout
// Each cell is a 4-byte key followed by a 293-byte serialized row (297 bytes).
// The 4086 bytes after the header hold at most 13 cells.
const (
	LEAF_NODE_KEY_SIZE        = 4
	LEAF_NODE_KEY_OFFSET      = 0
	LEAF_NODE_VALUE_SIZE      = ROW_SIZE
	LEAF_NODE_VALUE_OFFSET    = LEAF_NODE_KEY_OFFSET + LEAF_NODE_KEY_SIZE
	LEAF_NODE_CELL_SIZE       = LEAF_NODE_KEY_SIZE + LEAF_NODE_VALUE_SIZE
	LEAF_NODE_SPACE_FOR_CELLS = PAGE_SIZE - LEAF_NODE_HEADER_SIZE
	LEAF_NODE_MAX_CELLS       = LEAF_NODE_SPACE_FOR_CELLS / LEAF_NODE_CELL_SIZE
)

// InputBuffer holds the most recent line read by readInput.
// inputLength is set to len(buffer) by readInput. bufferLength is only ever
// initialised to 0 here — NOTE(review): looks like a leftover from the C version.
type InputBuffer struct {
	buffer       string
	bufferLength int
	inputLength  int
}

// MetaCommandResult is the outcome of doMetaCommand for a "."-prefixed line.
type MetaCommandResult int

const (
	META_COMMAND_SUCCESS MetaCommandResult = iota
	META_COMMAND_UNRECOGNIZED_COMMAND
)

// PrepareResult is the outcome of parsing a statement in prepareStatement.
type PrepareResult int

const (
	PREPARE_SUCCESS PrepareResult = iota
	PREPARE_NEGATIVE_ID
	PREPARE_STRING_TOO_LONG
	PREPARE_SYNTAX_ERROR
	PREPARE_UNRECOGNIZED_STATEMENT
)

// StatementType identifies which SQL-like statement was parsed.
type StatementType int

const (
	STATEMENT_INSERT StatementType = iota
	STATEMENT_SELECT
)

// Row is the in-memory form of a table row. The text columns are fixed-size,
// zero-padded byte arrays.
type Row struct {
	id       uint32
	username [COLUMN_USERNAME_SIZE + 1]byte
	email    [COLUMN_EMAIL_SIZE + 1]byte
}

// Statement is a parsed command; rowToInsert is only meaningful for inserts.
type Statement struct {
	typ         StatementType
	rowToInsert Row
}

// Pager owns the database file and a cache of its pages.
// fileLength is the size measured when the file was opened. numPages starts
// at the file's page count and grows as getPage hands out higher pages.
// A nil entry in pages means that page is not loaded.
type Pager struct {
	fileDescriptor *os.File
	fileLength     uint32
	numPages       uint32
	pages          [TABLE_MAX_PAGES][]byte
}

// Table is a B-tree whose root node lives on page rootPageNum.
type Table struct {
	rootPageNum uint32
	pager       *Pager
}

// Cursor points at cell cellNum of the node stored on page pageNum.
type Cursor struct {
	table      *Table
	pageNum    uint32
	cellNum    uint32
	endOfTable bool // marks the position one past the last element
}

// ExecuteResult is the outcome of executing a prepared statement.
// EXECUTE_DUPLICATE_KEY is new in this part, for inserts that reuse an existing key.
type ExecuteResult int

const (
	EXECUTE_SUCCESS ExecuteResult = iota
	EXECUTE_TABLE_FULL
	EXECUTE_DUPLICATE_KEY
)

// newInputBuffer allocates an InputBuffer whose fields all start at their
// zero values: an empty line and zero lengths.
func newInputBuffer() *InputBuffer {
	return new(InputBuffer)
}

// leafNodeNumCells returns a pointer to the uint32 cell count in the leaf
// header, which is stored in host byte order. The count sits at byte offset 6,
// which is not a multiple of 4. NOTE(review): this relies on the platform
// tolerating unaligned loads (true on amd64/arm64).
func leafNodeNumCells(node []byte) *uint32 {
	countPtr := unsafe.Pointer(&node[LEAF_NODE_NUM_CELLS_OFFSET])
	return (*uint32)(countPtr)
}

// leafNodeCell returns the LEAF_NODE_CELL_SIZE-byte slice (key followed by
// value) of cell cellNum in a leaf node.
func leafNodeCell(node []byte, cellNum uint32) []byte {
	cells := node[LEAF_NODE_HEADER_SIZE:]
	start := cellNum * LEAF_NODE_CELL_SIZE
	return cells[start : start+LEAF_NODE_CELL_SIZE]
}

// leafNodeKey returns a pointer to the uint32 key at the start of cell
// cellNum. The key is stored in host byte order.
func leafNodeKey(node []byte, cellNum uint32) *uint32 {
	cellStart := LEAF_NODE_HEADER_SIZE + cellNum*LEAF_NODE_CELL_SIZE
	keyPtr := unsafe.Pointer(&node[cellStart])
	return (*uint32)(keyPtr)
}

// leafNodeValue returns the serialized-row slice that follows the key in
// cell cellNum.
func leafNodeValue(node []byte, cellNum uint32) []byte {
	start := LEAF_NODE_HEADER_SIZE + cellNum*LEAF_NODE_CELL_SIZE + LEAF_NODE_KEY_SIZE
	end := start + LEAF_NODE_VALUE_SIZE
	return node[start:end]
}

// printConstants prints the row and leaf-node layout sizes, one "NAME: value"
// line each, for the ".constants" meta command.
func printConstants() {
	constants := []struct {
		name  string
		value int
	}{
		{"ROW_SIZE", ROW_SIZE},
		{"COMMON_NODE_HEADER_SIZE", COMMON_NODE_HEADER_SIZE},
		{"LEAF_NODE_HEADER_SIZE", LEAF_NODE_HEADER_SIZE},
		{"LEAF_NODE_CELL_SIZE", LEAF_NODE_CELL_SIZE},
		{"LEAF_NODE_SPACE_FOR_CELLS", LEAF_NODE_SPACE_FOR_CELLS},
		{"LEAF_NODE_MAX_CELLS", LEAF_NODE_MAX_CELLS},
	}
	for _, c := range constants {
		fmt.Printf("%s: %d\n", c.name, c.value)
	}
}

// printLeafNode dumps a leaf's cell count, then one "  - index : key" line per cell.
func printLeafNode(node []byte) {
	count := *leafNodeNumCells(node)
	fmt.Printf("leaf (size %d)\n", count)
	for cell := uint32(0); cell < count; cell++ {
		fmt.Printf("  - %d : %d\n", cell, *leafNodeKey(node, cell))
	}
}

// printRow writes row to stdout as "(id, username, email)". Trailing NUL
// padding of the fixed-size text columns is trimmed before printing.
func printRow(row *Row) {
	username := strings.TrimRight(string(row.username[:]), "\x00")
	email := strings.TrimRight(string(row.email[:]), "\x00")
	fmt.Printf("(%d, %s, %s)\n", row.id, username, email)
}

// serializeRow packs source into destination at the fixed row offsets.
// The id is copied as its raw in-memory bytes, i.e. in host byte order.
// destination is expected to hold at least ROW_SIZE bytes.
func serializeRow(source *Row, destination []byte) {
	idBytes := (*[ID_SIZE]byte)(unsafe.Pointer(&source.id))
	idField := destination[ID_OFFSET:]
	usernameField := destination[USERNAME_OFFSET:]
	emailField := destination[EMAIL_OFFSET:]

	copy(idField, idBytes[:])
	copy(usernameField, source.username[:])
	copy(emailField, source.email[:])
}

// deserializeRow unpacks a serialized row from source into destination.
// The id is read as a host-byte-order uint32 straight from the buffer.
func deserializeRow(source []byte, destination *Row) {
	destination.id = *(*uint32)(unsafe.Pointer(&source[ID_OFFSET]))

	usernameField := source[USERNAME_OFFSET : USERNAME_OFFSET+USERNAME_SIZE]
	emailField := source[EMAIL_OFFSET : EMAIL_OFFSET+EMAIL_SIZE]
	copy(destination.username[:], usernameField)
	copy(destination.email[:], emailField)
}

// getNodeType reads the node-type byte from the common node header.
func getNodeType(node []byte) NodeType {
	raw := node[NODE_TYPE_OFFSET]
	return NodeType(raw)
}

// setNodeType writes nodeType into the common node header's type byte.
func setNodeType(node []byte, nodeType NodeType) {
	raw := byte(nodeType)
	node[NODE_TYPE_OFFSET] = raw
}

// initializeLeafNode marks node as a leaf and sets its cell count to zero.
func initializeLeafNode(node []byte) {
	setNodeType(node, NODE_LEAF)
	numCells := leafNodeNumCells(node)
	*numCells = 0
}

// getPage returns the cached copy of page pageNum, reading it from the
// database file on first access.
//
// A page past the end of the file comes back zero-filled, and pager.numPages
// grows to cover any page handed out. The process exits if pageNum is outside
// the page table or the read fails.
func getPage(pager *Pager, pageNum uint32) []byte {
	// pages has exactly TABLE_MAX_PAGES slots, so TABLE_MAX_PAGES itself is
	// already out of range.
	if pageNum >= TABLE_MAX_PAGES {
		fmt.Printf("Tried to fetch page number out of bounds. %d >= %d\n", pageNum, TABLE_MAX_PAGES)
		os.Exit(1)
	}

	if pager.pages[pageNum] == nil {
		// Cache miss. Allocate memory and load from file.
		page := make([]byte, PAGE_SIZE)
		numPages := pager.fileLength / PAGE_SIZE

		// We might save a partial page at the end of the file
		if pager.fileLength%PAGE_SIZE != 0 {
			numPages++
		}

		if pageNum < numPages {
			// ReadAt keeps reading until the page is full or the file ends.
			// A short final page yields io.EOF and its remainder stays zeroed.
			_, err := pager.fileDescriptor.ReadAt(page, int64(pageNum)*PAGE_SIZE)
			if err != nil && err != io.EOF {
				fmt.Printf("Error reading file: %v\n", err)
				os.Exit(1)
			}
		}

		pager.pages[pageNum] = page
		if pageNum >= pager.numPages {
			pager.numPages = pageNum + 1
		}
	}

	return pager.pages[pageNum]
}

// tableStart returns a cursor on cell 0 of the root node. If the root has no
// cells, the cursor is already at the end of the table.
func tableStart(table *Table) *Cursor {
	root := getPage(table.pager, table.rootPageNum)
	return &Cursor{
		table:      table,
		pageNum:    table.rootPageNum,
		endOfTable: *leafNodeNumCells(root) == 0,
	}
}

// leafNodeFind binary-searches the leaf on page pageNum for key. It assumes
// the leaf's keys are sorted in ascending order, which binary search requires.
//
// It returns a cursor on the matching cell. When key is absent, the cursor is
// on the cell where key would be inserted to keep the order, which may be
// numCells, one past the last cell. The cursor always has endOfTable == false.
func leafNodeFind(table *Table, pageNum, key uint32) *Cursor {
	node := getPage(table.pager, pageNum)
	lo, hi := uint32(0), *leafNodeNumCells(node) // search window is [lo, hi)

	for lo < hi {
		mid := lo + (hi-lo)/2
		switch midKey := *leafNodeKey(node, mid); {
		case key == midKey:
			return &Cursor{table: table, pageNum: pageNum, cellNum: mid}
		case key < midKey:
			hi = mid
		default:
			lo = mid + 1
		}
	}

	return &Cursor{table: table, pageNum: pageNum, cellNum: lo}
}

// tableFind returns a cursor at key's position in the table. Only a leaf root
// is supported so far: an internal root prints a message and exits the process.
func tableFind(table *Table, key uint32) *Cursor {
	rootPageNum := table.rootPageNum
	if getNodeType(getPage(table.pager, rootPageNum)) != NODE_LEAF {
		fmt.Println("Need to implement searching an internal node")
		os.Exit(1)
	}
	return leafNodeFind(table, rootPageNum, key)
}

// cursorValue returns the serialized-row bytes of the cell under cursor.
func cursorValue(cursor *Cursor) []byte {
	return leafNodeValue(getPage(cursor.table.pager, cursor.pageNum), cursor.cellNum)
}

// cursorAdvance moves cursor to the next cell of its node. It sets
// endOfTable once the cursor passes the node's last cell.
func cursorAdvance(cursor *Cursor) {
	node := getPage(cursor.table.pager, cursor.pageNum)
	cursor.cellNum++
	if numCells := *leafNodeNumCells(node); cursor.cellNum >= numCells {
		cursor.endOfTable = true
	}
}

// pagerOpen opens filename for reading and writing, creating it (mode 0644)
// if needed, and returns a Pager with an empty page cache.
//
// pagerFlush always writes whole pages, so a length that is not a multiple of
// PAGE_SIZE is treated as corruption. The process exits on that, or if the
// file cannot be opened or measured.
func pagerOpen(filename string) *Pager {
	fileDescriptor, err := os.OpenFile(filename, os.O_RDWR|os.O_CREATE, 0644)
	if err != nil {
		fmt.Printf("Unable to open file: %v\n", err)
		os.Exit(1)
	}

	// io.SeekEnd replaces the deprecated os.SEEK_END.
	fileLength, err := fileDescriptor.Seek(0, io.SeekEnd)
	if err != nil {
		fmt.Printf("Error seeking: %v\n", err)
		os.Exit(1)
	}

	if fileLength%PAGE_SIZE != 0 {
		fmt.Printf("Db file is not a whole number of pages. Corrupt file.\n")
		os.Exit(1)
	}

	// The zero-valued pages array already marks every page as not cached.
	return &Pager{
		fileDescriptor: fileDescriptor,
		fileLength:     uint32(fileLength),
		numPages:       uint32(fileLength / PAGE_SIZE),
	}
}

// dbOpen opens the database file and wraps it in a Table rooted at page 0.
// For a brand-new (empty) file, page 0 is initialised as an empty leaf node.
func dbOpen(filename string) *Table {
	pager := pagerOpen(filename)
	if pager.numPages == 0 {
		initializeLeafNode(getPage(pager, 0))
	}
	return &Table{pager: pager, rootPageNum: 0}
}

// pagerFlush writes the cached page pageNum back to its slot in the database
// file. Flushing a page that was never loaded, or any write error, prints a
// message and exits the process.
func pagerFlush(pager *Pager, pageNum uint32) {
	page := pager.pages[pageNum]
	if page == nil {
		fmt.Printf("Tried to flush null page\n")
		os.Exit(1)
	}

	// Widen before multiplying: pageNum*PAGE_SIZE in uint32 would wrap for
	// large page numbers. WriteAt writes at an explicit offset, replacing the
	// former Seek (deprecated os.SEEK_SET) + offset check + Write sequence.
	offset := int64(pageNum) * PAGE_SIZE
	if _, err := pager.fileDescriptor.WriteAt(page[:PAGE_SIZE], offset); err != nil {
		fmt.Printf("Error writing: %v\n", err)
		os.Exit(1)
	}
}

// dbClose flushes every loaded page back to disk, closes the database file,
// and then terminates the process with exit status 0 — it never returns.
// Any close error exits with status 1 instead.
func dbClose(table *Table) {
	pager := table.pager

	for i := uint32(0); i < pager.numPages; i++ {
		if pager.pages[i] == nil {
			continue // never loaded, nothing to write back
		}
		pagerFlush(pager, i)
		pager.pages[i] = nil
	}

	if err := pager.fileDescriptor.Close(); err != nil {
		fmt.Printf("Error closing db file: %v\n", err)
		os.Exit(1)
	}

	// The C original frees each page buffer at this point. In Go the cache is
	// garbage and the process exits next, so no second pass over pager.pages
	// is needed.
	os.Exit(0)
}

// printPrompt writes the REPL prompt to standard output without a trailing
// newline, so the user's input appears on the same line.
func printPrompt() {
	const prompt = "db > "
	fmt.Print(prompt)
}

// readInput reads one line from reader into inputBuffer, dropping the line
// terminator ("\n" or "\r\n"). If the input ends (or fails) before a newline,
// it prints the error and exits the process with status 1.
//
// The caller must pass one long-lived reader: creating a new bufio.Reader per
// call would discard input it had already buffered.
func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
	line, err := reader.ReadString('\n')
	if err != nil {
		fmt.Println("Error reading input: ", err.Error())
		os.Exit(1)
	}

	// Strip "\n" and then an optional "\r", so CRLF input does not leave a
	// stray carriage return in the command (".exit\r" would not match ".exit").
	line = strings.TrimSuffix(line, "\n")
	line = strings.TrimSuffix(line, "\r")
	inputBuffer.inputLength = len(line)
	inputBuffer.buffer = line
}

// closeInputBuffer clears the buffered command text. The other fields are
// left untouched.
func closeInputBuffer(inputBuffer *InputBuffer) {
	inputBuffer.buffer = inputBuffer.buffer[:0]
}

// doMetaCommand runs a dot-prefixed meta command:
//   - .exit closes the input buffer and the database (dbClose exits the process)
//   - .btree prints page 0 as a leaf node
//   - .constants prints the layout constants
//
// Any other text yields META_COMMAND_UNRECOGNIZED_COMMAND.
func doMetaCommand(inputBuffer *InputBuffer, table *Table) MetaCommandResult {
	switch inputBuffer.buffer {
	case ".exit":
		closeInputBuffer(inputBuffer)
		dbClose(table)
	case ".btree":
		fmt.Println("Tree:")
		printLeafNode(getPage(table.pager, 0))
	case ".constants":
		fmt.Println("Constants:")
		printConstants()
	default:
		return META_COMMAND_UNRECOGNIZED_COMMAND
	}
	return META_COMMAND_SUCCESS
}

// prepareInsert parses "insert <id> <username> <email>" into statement.
//
// It returns:
//   - PREPARE_SYNTAX_ERROR when there are not exactly four tokens, or when the
//     id is not an integer that fits in the uint32 key
//   - PREPARE_NEGATIVE_ID for negative ids
//   - PREPARE_STRING_TOO_LONG when username/email exceed their column sizes
//   - PREPARE_SUCCESS otherwise
func prepareInsert(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	statement.typ = STATEMENT_INSERT

	tokens := strings.Fields(inputBuffer.buffer)
	if len(tokens) != 4 {
		return PREPARE_SYNTAX_ERROR
	}

	id, err := strconv.Atoi(tokens[1])
	if err != nil {
		// A non-numeric id is a syntax problem; reporting it as a negative id
		// ("ID must be positive.") was misleading.
		return PREPARE_SYNTAX_ERROR
	}
	if id < 0 {
		return PREPARE_NEGATIVE_ID
	}
	if int64(id) > int64(^uint32(0)) {
		// Would silently wrap when stored in the uint32 row id / B-tree key.
		return PREPARE_SYNTAX_ERROR
	}

	if len(tokens[2]) > COLUMN_USERNAME_SIZE || len(tokens[3]) > COLUMN_EMAIL_SIZE {
		return PREPARE_STRING_TOO_LONG
	}

	statement.rowToInsert.id = uint32(id)
	copy(statement.rowToInsert.username[:], tokens[2])
	copy(statement.rowToInsert.email[:], tokens[3])

	return PREPARE_SUCCESS
}

// prepareStatement dispatches on the first whitespace-separated word of the
// input: "insert" is parsed by prepareInsert, and "select" needs no
// arguments. Anything else, including blank input, is unrecognized.
func prepareStatement(inputBuffer *InputBuffer, statement *Statement) PrepareResult {
	var keyword string
	if words := strings.Fields(inputBuffer.buffer); len(words) > 0 {
		keyword = words[0]
	}

	if keyword == "insert" {
		return prepareInsert(inputBuffer, statement)
	}
	if keyword == "select" {
		statement.typ = STATEMENT_SELECT
		return PREPARE_SUCCESS
	}
	return PREPARE_UNRECOGNIZED_STATEMENT
}

// leafNodeInsert stores (key, value) in the leaf the cursor points at. Cells
// from cursor.cellNum onward are first shifted one slot to the right, which
// keeps keys ordered when the cursor came from leafNodeFind. A full leaf
// aborts the process, because splitting is not implemented at this stage.
func leafNodeInsert(cursor *Cursor, key uint32, value *Row) {
	node := getPage(cursor.table.pager, cursor.pageNum)

	numCells := *leafNodeNumCells(node)
	if numCells >= LEAF_NODE_MAX_CELLS {
		// Node full
		fmt.Println("Need to implement splitting a leaf node.")
		os.Exit(1)
	}

	// Make room for the new cell, moving one cell per step from the right.
	// The copy is bounded to exactly one cell, as in leafNodeSplitAndInsert,
	// so the shift stays correct even if leafNodeCell's slice extends past
	// the cell. An unbounded copy would drag every later cell along on each
	// step and lose data when shifting more than one cell.
	for i := numCells; i > cursor.cellNum; i-- {
		copy(leafNodeCell(node, i), leafNodeCell(node, i-1)[:LEAF_NODE_CELL_SIZE])
	}

	*leafNodeNumCells(node) += 1
	*leafNodeKey(node, cursor.cellNum) = key
	serializeRow(value, leafNodeValue(node, cursor.cellNum))
}

// executeInsert inserts statement.rowToInsert keyed by its id.
// It returns EXECUTE_TABLE_FULL when the root leaf has no free cell,
// EXECUTE_DUPLICATE_KEY when the id already exists, and EXECUTE_SUCCESS otherwise.
func executeInsert(statement *Statement, table *Table) ExecuteResult {
	root := getPage(table.pager, table.rootPageNum)
	if *leafNodeNumCells(root) >= LEAF_NODE_MAX_CELLS {
		return EXECUTE_TABLE_FULL
	}

	rowToInsert := &statement.rowToInsert
	keyToInsert := rowToInsert.id
	cursor := tableFind(table, keyToInsert)

	// Check for a duplicate in the leaf the cursor actually landed on, not in
	// the root page: once the root can be an internal node those differ.
	leaf := getPage(table.pager, cursor.pageNum)
	if cursor.cellNum < *leafNodeNumCells(leaf) && *leafNodeKey(leaf, cursor.cellNum) == keyToInsert {
		return EXECUTE_DUPLICATE_KEY
	}

	leafNodeInsert(cursor, keyToInsert, rowToInsert)

	return EXECUTE_SUCCESS
}

// executeSelect prints every row, walking a cursor from the start of the
// table until it reports the end.
func executeSelect(statement *Statement, table *Table) ExecuteResult {
	var row Row
	for cursor := tableStart(table); !cursor.endOfTable; cursorAdvance(cursor) {
		deserializeRow(cursorValue(cursor), &row)
		printRow(&row)
	}
	return EXECUTE_SUCCESS
}

// executeStatement routes a prepared statement to its executor. Statement
// types other than insert/select are treated as a successful no-op.
func executeStatement(statement *Statement, table *Table) ExecuteResult {
	if statement.typ == STATEMENT_INSERT {
		return executeInsert(statement, table)
	}
	if statement.typ == STATEMENT_SELECT {
		return executeSelect(statement, table)
	}
	return EXECUTE_SUCCESS
}

// main runs the REPL against the database file named by the first
// command-line argument. Lines starting with '.' are meta commands, and all
// other lines are parsed and executed as statements. The loop only ends when
// a called function exits the process (".exit", an input error, ...).
func main() {
	if len(os.Args) < 2 {
		fmt.Println("Must supply a database filename.")
		os.Exit(1)
	}

	filename := os.Args[1]
	table := dbOpen(filename)

	inputBuffer := newInputBuffer()
	// One reader for the whole session so buffered-but-unread input is kept.
	reader := bufio.NewReader(os.Stdin)
	for {
		printPrompt()
		readInput(reader, inputBuffer)

		// HasPrefix instead of buffer[0]: an empty line used to panic with an
		// index out of range. It now falls through to the statement parser and
		// reports an unrecognized keyword, like the C version.
		if strings.HasPrefix(inputBuffer.buffer, ".") {
			switch doMetaCommand(inputBuffer, table) {
			case META_COMMAND_SUCCESS:
				continue
			case META_COMMAND_UNRECOGNIZED_COMMAND:
				fmt.Printf("Unrecognized command '%s'\n", inputBuffer.buffer)
				continue
			}
		}

		var statement Statement
		switch prepareStatement(inputBuffer, &statement) {
		case PREPARE_SUCCESS:
			// Parsed; execute below.
		case PREPARE_NEGATIVE_ID:
			fmt.Println("ID must be positive.")
			continue
		case PREPARE_STRING_TOO_LONG:
			fmt.Println("String is too long.")
			continue
		case PREPARE_SYNTAX_ERROR:
			fmt.Println("Syntax error. Could not parse statement.")
			continue
		case PREPARE_UNRECOGNIZED_STATEMENT:
			fmt.Printf("Unrecognized keyword at start of '%s'.\n", inputBuffer.buffer)
			continue
		}

		switch executeStatement(&statement, table) {
		case EXECUTE_SUCCESS:
			fmt.Println("Executed.")
		case EXECUTE_TABLE_FULL:
			fmt.Println("Error: Table full.")
		case EXECUTE_DUPLICATE_KEY:
			fmt.Println("Error: Duplicate key.")
		}
	}
}

第十部分 - 分裂叶子节点

测试:

import sys,os
from util import run_script

# Test that rows inserted out of order end up sorted inside the single leaf.
def test_btree_order_structure(db_file=""):
    keys = (3, 1, 2)
    script = [f"insert {k} user{k} person{k}@example.com" for k in keys]
    script.extend([".btree", ".exit"])

    result = run_script(script, db_file=db_file)

    # One "Executed." per insert, then the tree dump with keys in ascending order.
    expected_result = ["db > Executed." for _ in keys]
    expected_result += ["db > Tree:", f"- leaf (size {len(keys)})"]
    expected_result += [f"  - {k}" for k in sorted(keys)]
    expected_result.append("db > ")

    print(f"result: {result}")
    assert result == expected_result
    print(f"{sys._getframe().f_code.co_name} passed")

# Test the `.btree` dump after the root leaf has split into one internal node
# over two 7-cell leaves. Searching an internal root is not implemented in this
# part yet, so the following insert is expected to abort with that message.
def test_print_structure_of_3_leaf_node_btree(db_file=""):
    script = [f"insert {i} user{i} person{i}@example.com" for i in range(1, 15)]
    script += [".btree", "insert 15 user15 person15@example.com", ".exit"]
    result = run_script(script, db_file=db_file, is_remove=True)

    expected_result = ["db > Tree:", "- internal (size 1)", "  - leaf (size 7)"]
    expected_result += [f"    - {i}" for i in range(1, 8)]
    expected_result += ["  - key 7", "  - leaf (size 7)"]
    expected_result += [f"    - {i}" for i in range(8, 15)]
    expected_result.append("db > Need to implement searching an internal node")

    print(f"result: {result}")
    # Skip the 14 lines answering the inserts that precede the tree dump.
    assert result[14:] == expected_result
    print(f"{sys._getframe().f_code.co_name} passed")

# Usage: python <this script> <db file path>
# Any existing database file is deleted first so each run starts empty.
if len(sys.argv) < 2:
    # sys.exit(msg) prints msg to stderr and exits with status 1, so a missing
    # argument is reported as a failure rather than a silent success (status 0).
    # It also avoids the `exit` builtin, which only exists when `site` is loaded.
    sys.exit("need db file path")
db_file = sys.argv[1]
if os.path.exists(db_file):
    print(f"{db_file} exists, remove it first")
    os.remove(db_file)

test_btree_order_structure(db_file)
test_print_structure_of_3_leaf_node_btree(db_file)

print("all tests passed.")

diff golang/9.go golang/10.go

62a63,84
> /*
>  * Leaf Node Split
>  */
> const LEAF_NODE_RIGHT_SPLIT_COUNT = (LEAF_NODE_MAX_CELLS + 1) / 2
> const LEAF_NODE_LEFT_SPLIT_COUNT = LEAF_NODE_MAX_CELLS + 1 - LEAF_NODE_RIGHT_SPLIT_COUNT
> 
> /*
>  * Internal Node Header Layout
>  */
> const INTERNAL_NODE_NUM_KEYS_SIZE = 4
> const INTERNAL_NODE_NUM_KEYS_OFFSET = COMMON_NODE_HEADER_SIZE
> const INTERNAL_NODE_RIGHT_CHILD_SIZE = 4
> const INTERNAL_NODE_RIGHT_CHILD_OFFSET = INTERNAL_NODE_NUM_KEYS_OFFSET + INTERNAL_NODE_NUM_KEYS_SIZE
> const INTERNAL_NODE_HEADER_SIZE = COMMON_NODE_HEADER_SIZE + INTERNAL_NODE_NUM_KEYS_SIZE + INTERNAL_NODE_RIGHT_CHILD_SIZE
> 
> /*
>  * Internal Node Body Layout
>  */
> const INTERNAL_NODE_KEY_SIZE = 4
> const INTERNAL_NODE_CHILD_SIZE = 4
> const INTERNAL_NODE_CELL_SIZE = INTERNAL_NODE_CHILD_SIZE + INTERNAL_NODE_KEY_SIZE
> 
168,173c190,192
< func printLeafNode(node []byte) {
<       numCells := *leafNodeNumCells(node)
<       fmt.Printf("leaf (size %d)\n", numCells)
<       for i := uint32(0); i < numCells; i++ {
<               key := *leafNodeKey(node, i)
<               fmt.Printf("  - %d : %d\n", i, key)
---
> func indent(level uint32) {
>       for i := uint32(0); i < level; i++ {
>               fmt.Print("  ")
199a219,239
> }
> 
> func isNodeRoot(node []byte) bool {
>       value := node[IS_ROOT_OFFSET]
>       return value != 0
> }
> 
> func setNodeRoot(node []byte, isRoot bool) {
>       if isRoot {
>               node[IS_ROOT_OFFSET] = 1
>       } else {
>               node[IS_ROOT_OFFSET] = 0
>       }
> }
> 
> func internalNodeNumKeys(node []byte) *uint32 {
>       return (*uint32)(unsafe.Pointer(&node[INTERNAL_NODE_NUM_KEYS_OFFSET]))
> }
> 
> func internalNodeRightChild(node []byte) *uint32 {
>       return (*uint32)(unsafe.Pointer(&node[INTERNAL_NODE_RIGHT_CHILD_OFFSET]))
201a242,246
> func internalNodeCell(node []byte, cellNum uint32) *uint32 {
>       offset := INTERNAL_NODE_HEADER_SIZE + cellNum*INTERNAL_NODE_CELL_SIZE
>       return (*uint32)(unsafe.Pointer(&node[offset]))
> }
> 
203a249
>       setNodeRoot(node, false)
206a253,258
> func initializeInternalNode(node []byte) {
>       setNodeType(node, NODE_INTERNAL)
>       setNodeRoot(node, false)
>       *internalNodeNumKeys(node) = 0
> }
> 
243a296,341
> }
> 
> func internalNodeChild(node []byte, childNum uint32) *uint32 {
>       numKeys := *internalNodeNumKeys(node)
>       if childNum > numKeys {
>               fmt.Printf("Tried to access childNum %d > numKeys %d\n", childNum, numKeys)
>               os.Exit(1)
>       }
>       if childNum == numKeys {
>               return internalNodeRightChild(node)
>       }
>       return internalNodeCell(node, childNum)
> }
> 
> func internalNodeKey(node []byte, keyNum uint32) *uint32 {
>       offset := INTERNAL_NODE_HEADER_SIZE + keyNum*INTERNAL_NODE_CELL_SIZE + INTERNAL_NODE_CHILD_SIZE
>       return (*uint32)(unsafe.Pointer(&node[offset]))
> }
> 
> func printTree(pager *Pager, pageNum, indentationLevel uint32) {
>       node := getPage(pager, pageNum)
>       numKeys, child := uint32(0), uint32(0)
> 
>       switch getNodeType(node) {
>       case NODE_LEAF:
>               numKeys = *leafNodeNumCells(node)
>               indent(indentationLevel)
>               fmt.Printf("- leaf (size %d)\n", numKeys)
>               for i := uint32(0); i < numKeys; i++ {
>                       indent(indentationLevel + 1)
>                       fmt.Printf("- %d\n", *leafNodeKey(node, i))
>               }
>       case NODE_INTERNAL:
>               numKeys = *internalNodeNumKeys(node)
>               indent(indentationLevel)
>               fmt.Printf("- internal (size %d)\n", numKeys)
>               for i := uint32(0); i < numKeys; i++ {
>                       child = *internalNodeChild(node, i)
>                       printTree(pager, child, indentationLevel+1)
> 
>                       indent(indentationLevel + 1)
>                       fmt.Printf("- key %d\n", *internalNodeKey(node, i))
>               }
>               child = *internalNodeRightChild(node)
>               printTree(pager, child, indentationLevel+1)
>       }
358a457
>               setNodeRoot(rootNode, true)
445c544
<               printLeafNode(getPage(table.pager, 0))
---
>               printTree(table.pager, 0, 0)
499a599,616
> }
> 
> func getNodeMaxKey(node []byte) uint32 {
>       switch getNodeType(node) {
>       case NODE_INTERNAL:
>               numKeys := *internalNodeNumKeys(node)
>               return *internalNodeKey(node, numKeys-1)
>       case NODE_LEAF:
>               numCells := *leafNodeNumCells(node)
>               return *leafNodeKey(node, numCells-1)
>       default:
>               // Handle other node types if needed
>               return 0 // or appropriate default value
>       }
> }
> 
> func getUnusedPageNum(pager *Pager) uint32 {
>       return pager.numPages
501a619,683
> func createNewRoot(table *Table, rightChildPageNum uint32) {
>       root := getPage(table.pager, table.rootPageNum)
>       //rightChild := getPage(table.pager, rightChildPageNum)
>       leftChildPageNum := getUnusedPageNum(table.pager)
>       leftChild := getPage(table.pager, leftChildPageNum)
> 
>       // Left child gets data copied from the old root
>       copy(leftChild, root[:])
>       setNodeRoot(leftChild, false)
> 
>       // Root becomes a new internal node with one key and two children
>       initializeInternalNode(root)
>       setNodeRoot(root, true)
>       *internalNodeNumKeys(root) = 1
>       *internalNodeChild(root, 0) = leftChildPageNum
> 
>       leftChildMaxKey := getNodeMaxKey(leftChild)
>       *internalNodeKey(root, 0) = leftChildMaxKey
>       *internalNodeRightChild(root) = rightChildPageNum
> }
> 
> // 创建一个新节点并将一半单元格移动过去。
> // 在两个节点中的一个中插入新值。
> // 更新父节点或创建一个新的父节点。
> func leafNodeSplitAndInsert(cursor *Cursor, key uint32, value *Row) {
>       oldNode := getPage(cursor.table.pager, cursor.pageNum)
>       newPageNum := getUnusedPageNum(cursor.table.pager)
>       newNode := getPage(cursor.table.pager, newPageNum)
>       initializeLeafNode(newNode)
> 
>       /*
>         所有现有键以及新键应该均匀分布
>         在旧(左)和新(右)节点之间。
>         从右侧开始,将每个键移动到正确的位置。
>       */
>       for i := LEAF_NODE_MAX_CELLS; i >= 0; i-- {
>               var destinationNode []byte
>               if i >= LEAF_NODE_LEFT_SPLIT_COUNT {
>                       destinationNode = newNode
>               } else {
>                       destinationNode = oldNode
>               }
>               indexWithinNode := i % LEAF_NODE_LEFT_SPLIT_COUNT
>               destination := leafNodeCell(destinationNode, uint32(indexWithinNode))
> 
>               if i == int(cursor.cellNum) {
>                       serializeRow(value, destination)
>               } else if i > int(cursor.cellNum) {
>                       copy(destination, leafNodeCell(oldNode, uint32(i-1))[:LEAF_NODE_CELL_SIZE])
>               } else {
>                       copy(destination, leafNodeCell(oldNode, uint32(i))[:LEAF_NODE_CELL_SIZE])
>               }
>       }
> 
>       /* 在两个叶子节点上更新单元格计数 */
>       *leafNodeNumCells(oldNode) = LEAF_NODE_LEFT_SPLIT_COUNT
>       *leafNodeNumCells(newNode) = LEAF_NODE_RIGHT_SPLIT_COUNT
>       if isNodeRoot(oldNode) {
>               createNewRoot(cursor.table, newPageNum)
>       } else {
>               fmt.Println("Need to implement updating parent after split")
>               os.Exit(1)
>       }
> }
> 
507,509c689,690
<               // Node full
<               fmt.Println("Need to implement splitting a leaf node.")
<               os.Exit(1)
---
>               leafNodeSplitAndInsert(cursor, key, value)
>               return
527,529d707
<       if numCells >= LEAF_NODE_MAX_CELLS {
<               return EXECUTE_TABLE_FULL
<       }

第十一部分 - 递归搜索 B 树

测试:

import sys,os
from util import run_script

# Test the `.btree` dump of the split tree (one internal node over two 7-cell
# leaves). With recursive search in place, the following insert must succeed.
def test_print_structure_of_3_leaf_node_btree(db_file=""):
    script = [f"insert {i} user{i} person{i}@example.com" for i in range(1, 15)]
    script += [".btree", "insert 15 user15 person15@example.com", ".exit"]
    result = run_script(script, db_file=db_file)

    expected_result = ["db > Tree:", "- internal (size 1)", "  - leaf (size 7)"]
    expected_result += [f"    - {i}" for i in range(1, 8)]
    expected_result += ["  - key 7", "  - leaf (size 7)"]
    expected_result += [f"    - {i}" for i in range(8, 15)]
    expected_result += ["db > Executed.", "db > "]

    print(f"result: {result}")
    # Skip the 14 lines answering the inserts that precede the tree dump.
    assert result[14:] == expected_result
    print(f"{sys._getframe().f_code.co_name} passed")

# Keep inserting until a non-root leaf has to split. Updating the parent after
# such a split is not implemented in this part, so the run is expected to stop
# with that message right after a successful insert (not with "Table full").
def test_prints_error_message_when_table_is_full(db_file=""):
    script = [*(f"insert {i} user{i} person{i}@example.com" for i in range(1, 1402)), ".exit"]
    result = run_script(script, db_file=db_file, is_remove=True)

    expected_result = [
        "db > Executed.",
        "db > Need to implement updating parent after split",
    ]
    print(f"result: {result}")
    assert result[-2:] == expected_result, "Test failed"
    print(f"{sys._getframe().f_code.co_name} passed")


# Usage: python <this script> <db file path>
# Any existing database file is deleted first so each run starts empty.
if len(sys.argv) < 2:
    # sys.exit(msg) prints msg to stderr and exits with status 1, so a missing
    # argument is reported as a failure rather than a silent success (status 0).
    # It also avoids the `exit` builtin, which only exists when `site` is loaded.
    sys.exit("need db file path")
db_file = sys.argv[1]
if os.path.exists(db_file):
    print(f"{db_file} exists, remove it first")
    os.remove(db_file)

test_print_structure_of_3_leaf_node_btree(db_file)
test_prints_error_message_when_table_is_full(db_file)

print("all tests passed.")

diff golang/10.go golang/11.go

383a384,391
> 
> func internalNodeFind(table *Table, pageNum, key uint32) *Cursor {
>       node := getPage(table.pager, pageNum)
>       numKeys := *internalNodeNumKeys(node)
> 
>       // Binary search to find index of child to search
>       minIndex := uint32(0)
>       maxIndex := numKeys // there is one more child than key
384a393,416
>       for minIndex != maxIndex {
>               index := (minIndex + maxIndex) / 2
>               keyToRight := *internalNodeKey(node, index)
>               if keyToRight >= key {
>                       maxIndex = index
>               } else {
>                       minIndex = index + 1
>               }
>       }
> 
>       childNum := *internalNodeChild(node, minIndex)
>       child := getPage(table.pager, childNum)
> 
>       switch getNodeType(child) {
>       case NODE_LEAF:
>               return leafNodeFind(table, childNum, key)
>       case NODE_INTERNAL:
>               return internalNodeFind(table, childNum, key)
>       default:
>               // Handle other node types if needed
>               return nil
>       }
> }
> 
393,394c425
<               fmt.Println("Need to implement searching an internal node")
<               os.Exit(1)
---
>               return internalNodeFind(table, rootPageNum, key)

第十二部分 - 扫描多层 B 树

测试:

import sys,os
from util import run_script

# Test that `select` returns all 15 rows, in key order, once they span more than one leaf.
def test_prints_all_rows_in_multi_level_tree(db_file=""):
    ids = range(1, 16)
    script = [f"insert {i} user{i} person{i}@example.com" for i in ids]
    script += ["select", ".exit"]
    result = run_script(script, db_file=db_file)

    rows = [f"({i}, user{i}, person{i}@example.com)" for i in ids]
    # The first row shares its line with the prompt printed before `select`.
    expected_output = ["db > " + rows[0], *rows[1:], "Executed.", "db > "]
    print(f"result: {result}")
    # Skip the 15 lines answering the inserts.
    assert result[15:] == expected_output
    print(f"{sys._getframe().f_code.co_name} passed")

# Test the output of the `.constants` meta command.
def test_print_constants(db_file=""):
    script = [".constants", ".exit"]
    result = run_script(script, db_file=db_file)

    # Expected layout constants, in the order they are printed.
    constants = {
        "ROW_SIZE": 293,
        "COMMON_NODE_HEADER_SIZE": 6,
        "LEAF_NODE_HEADER_SIZE": 14,
        "LEAF_NODE_CELL_SIZE": 297,
        "LEAF_NODE_SPACE_FOR_CELLS": 4082,
        "LEAF_NODE_MAX_CELLS": 13,
    }
    expected_result = ["db > Constants:"]
    expected_result += [f"{name}: {value}" for name, value in constants.items()]
    expected_result.append("db > ")

    assert result == expected_result
    print(f"{sys._getframe().f_code.co_name} passed")

# Usage: python <this script> <db file path>
# Any existing database file is deleted first so each run starts empty.
if len(sys.argv) < 2:
    # sys.exit(msg) prints msg to stderr and exits with status 1, so a missing
    # argument is reported as a failure rather than a silent success (status 0).
    # It also avoids the `exit` builtin, which only exists when `site` is loaded.
    sys.exit("need db file path")
db_file = sys.argv[1]
if os.path.exists(db_file):
    print(f"{db_file} exists, remove it first")
    os.remove(db_file)

test_prints_all_rows_in_multi_level_tree(db_file=db_file)
test_print_constants(db_file=db_file)

print("all tests passed.")

diff golang/11.go golang/12.go

49c49,51
<       LEAF_NODE_HEADER_SIZE      = COMMON_NODE_HEADER_SIZE + LEAF_NODE_NUM_CELLS_SIZE
---
>       LEAF_NODE_NEXT_LEAF_SIZE   = 4
>       LEAF_NODE_NEXT_LEAF_OFFSET = LEAF_NODE_NUM_CELLS_OFFSET + LEAF_NODE_NUM_CELLS_SIZE
>       LEAF_NODE_HEADER_SIZE      = COMMON_NODE_HEADER_SIZE + LEAF_NODE_NUM_CELLS_SIZE + LEAF_NODE_NEXT_LEAF_SIZE
178a181,184
> }
> 
> func leafNodeNextLeaf(node []byte) *uint32 {
>       return (*uint32)(unsafe.Pointer(&node[LEAF_NODE_NEXT_LEAF_OFFSET]))
250a257
>       *leafNodeNextLeaf(node) = 0 // 0 表示无兄弟节点
341,349d347
<       }
< }
< 
< func tableStart(table *Table) *Cursor {
<       cursor := &Cursor{
<               table:      table,
<               pageNum:    table.rootPageNum,
<               cellNum:    0,
<               endOfTable: false,
351,356d348
< 
<       rootNode := getPage(table.pager, table.rootPageNum)
<       numCells := *leafNodeNumCells(rootNode)
<       cursor.endOfTable = (numCells == 0)
< 
<       return cursor
441c433,441
<               cursor.endOfTable = true
---
>               /* 前进到下一个叶子节点 */
>               nextPageNum := *leafNodeNextLeaf(node)
>               if nextPageNum == 0 {
>                       /* 这是最右边的叶子节点 */
>                       cursor.endOfTable = true
>               } else {
>                       cursor.pageNum = nextPageNum
>                       cursor.cellNum = 0
>               }
678a679,680
>       *leafNodeNextLeaf(newNode) = *leafNodeNextLeaf(oldNode)
>       *leafNodeNextLeaf(oldNode) = newPageNum
696c698,699
<                       serializeRow(value, destination)
---
>                       serializeRow(value, leafNodeValue(destinationNode, uint32(indexWithinNode)))
>                       *leafNodeKey(destinationNode, uint32(indexWithinNode)) = key
754a758,766
> func tableStart(table *Table) *Cursor {
>       cursor := tableFind(table, 0)
>       node := getPage(table.pager, cursor.pageNum)
>       numCells := *leafNodeNumCells(node)
>       cursor.endOfTable = numCells == 0
> 
>       return cursor
> }
> 

第十三部分 - 分裂叶子节点后更新父节点

测试:

import sys,os
from util import run_script

# Keep inserting until an internal node runs out of room (INTERNAL_NODE_MAX_CELLS
# is 3 in this part). Splitting internal nodes is not implemented yet, so the
# run is expected to stop with that message right after a successful insert.
def test_prints_error_message_when_table_is_full(db_file=""):
    script = [*(f"insert {i} user{i} person{i}@example.com" for i in range(1, 1402)), ".exit"]
    result = run_script(script, db_file=db_file, is_remove=True)

    expected_result = [
        "db > Executed.",
        "db > Need to implement splitting internal node",
    ]
    print(f"result: {result}")
    assert result[-2:] == expected_result, "Test failed"
    print(f"{sys._getframe().f_code.co_name} passed")

# Test the `.btree` dump of a tree with one internal node over four leaves,
# built from 30 keys inserted in a scrambled order.
def test_prints_structure_of_4_leaf_node_btree(db_file=""):
    insert_order = [
        18, 7, 10, 29, 23, 4, 14, 30, 15, 26,
        22, 19, 2, 1, 21, 11, 6, 20, 5, 8,
        9, 3, 12, 27, 17, 16, 13, 24, 25, 28,
    ]
    script = [f"insert {k} user{k} person{k}@example.com" for k in insert_order]
    script += [".btree", ".exit"]
    result = run_script(script, db_file=db_file)

    # Keys held by each leaf, left to right. Each internal key is the largest
    # key of the leaf to its left, so no key line follows the last leaf.
    leaves = [range(1, 8), range(8, 16), range(16, 23), range(23, 31)]
    expected_output = ["db > Tree:", f"- internal (size {len(leaves) - 1})"]
    for position, leaf in enumerate(leaves):
        if position > 0:
            expected_output.append(f"  - key {leaves[position - 1][-1]}")
        expected_output.append(f"  - leaf (size {len(leaf)})")
        expected_output.extend(f"    - {k}" for k in leaf)
    expected_output.append("db > ")

    print(f"result: {result}")
    # Skip the 30 lines answering the inserts.
    assert result[30:] == expected_output
    print(f"{sys._getframe().f_code.co_name} passed")


# Usage: python <this script> <db file path>
# Any existing database file is deleted first so each run starts empty.
if len(sys.argv) < 2:
    # sys.exit(msg) prints msg to stderr and exits with status 1, so a missing
    # argument is reported as a failure rather than a silent success (status 0).
    # It also avoids the `exit` builtin, which only exists when `site` is loaded.
    sys.exit("need db file path")
db_file = sys.argv[1]
if os.path.exists(db_file):
    print(f"{db_file} exists, remove it first")
    os.remove(db_file)

test_prints_error_message_when_table_is_full(db_file)
test_prints_structure_of_4_leaf_node_btree(db_file)

print("all tests passed.")

diff golang/12.go golang/13.go

85a86
> const INTERNAL_NODE_MAX_CELLS = 3
184a186,189
> }
> 
> func nodeParent(node []byte) *uint32 {
>       return (*uint32)(unsafe.Pointer(&node[PARENT_POINTER_OFFSET]))
319a325,399
> }
> 
> // 返回应包含给定键的子节点的索引。
> func internalNodeFindChild(node []byte, key uint32) uint32 {
>       numKeys := *internalNodeNumKeys(node)
> 
>       // Binary search
>       minIndex := uint32(0)
>       maxIndex := numKeys // there is one more child than key
> 
>       for minIndex != maxIndex {
>               index := (minIndex + maxIndex) / 2
>               keyToRight := *internalNodeKey(node, index)
> 
>               if keyToRight >= key {
>                       maxIndex = index
>               } else {
>                       minIndex = index + 1
>               }
>       }
> 
>       return minIndex
> }
> 
> func updateInternalNodeKey(node []byte, oldKey, newKey uint32) {
>       oldChildIndex := internalNodeFindChild(node, oldKey)
>       *internalNodeKey(node, oldChildIndex) = newKey
> }
> 
> func getNodeMaxKey(node []byte) uint32 {
>       switch getNodeType(node) {
>       case NODE_INTERNAL:
>               return *internalNodeKey(node, *internalNodeNumKeys(node)-1)
>       case NODE_LEAF:
>               return *leafNodeKey(node, *leafNodeNumCells(node)-1)
>       default:
>               // Handle other node types if needed
>               return 0 // or appropriate default value
>       }
> }
> 
> func internalNodeInsert(table *Table, parentPageNum, childPageNum uint32) {
>       parent := getPage(table.pager, parentPageNum)
>       child := getPage(table.pager, childPageNum)
>       childMaxKey := getNodeMaxKey(child)
>       index := internalNodeFindChild(parent, childMaxKey)
> 
>       originalNumKeys := *internalNodeNumKeys(parent)
>       *internalNodeNumKeys(parent) = originalNumKeys + 1
> 
>       if originalNumKeys >= INTERNAL_NODE_MAX_CELLS {
>               fmt.Println("Need to implement splitting internal node")
>               os.Exit(1)
>       }
> 
>       rightChildPageNum := *internalNodeRightChild(parent)
>       rightChild := getPage(table.pager, rightChildPageNum)
> 
>       if childMaxKey > getNodeMaxKey(rightChild) {
>               // Replace right child
>               *internalNodeChild(parent, originalNumKeys) = rightChildPageNum
>               *internalNodeKey(parent, originalNumKeys) = getNodeMaxKey(rightChild)
>               *internalNodeRightChild(parent) = childPageNum
>       } else {
>               // Make space for the new cell
>               for i := originalNumKeys; i > index; i-- {
>                       //destination := internalNodeCell(parent, i)
>                       //source := internalNodeCell(parent, i-1)
>                       // c: memcpy(destination, source, INTERNAL_NODE_CELL_SIZE);
>                       //copy(destination, source)
>                       *internalNodeCell(parent, i) = *internalNodeCell(parent, i-1)
>               }
>               *internalNodeChild(parent, index) = childPageNum
>               *internalNodeKey(parent, index) = childMaxKey
>       }
632,645d711
< func getNodeMaxKey(node []byte) uint32 {
<       switch getNodeType(node) {
<       case NODE_INTERNAL:
<               numKeys := *internalNodeNumKeys(node)
<               return *internalNodeKey(node, numKeys-1)
<       case NODE_LEAF:
<               numCells := *leafNodeNumCells(node)
<               return *leafNodeKey(node, numCells-1)
<       default:
<               // Handle other node types if needed
<               return 0 // or appropriate default value
<       }
< }
< 
652c718
<       //rightChild := getPage(table.pager, rightChildPageNum)
---
>       rightChild := getPage(table.pager, rightChildPageNum)
668a735,736
>       *nodeParent(leftChild) = table.rootPageNum
>       *nodeParent(rightChild) = table.rootPageNum
675a744
>       oldMax := getNodeMaxKey(oldNode)
678a748
>       *nodeParent(newNode) = *nodeParent(oldNode)
713,714c783,788
<               fmt.Println("Need to implement updating parent after split")
<               os.Exit(1)
---
>               parentPageNum := *nodeParent(oldNode)
>               newMax := getNodeMaxKey(oldNode)
>               parent := getPage(cursor.table.pager, parentPageNum)
> 
>               updateInternalNodeKey(parent, oldMax, newMax)
>               internalNodeInsert(cursor.table, parentPageNum, newPageNum)

第十四部分 - 拆分内部节点

测试:

import sys,os
from util import run_script

# 测试7个叶子节点的B+树的结构
def test_prints_structure_of_7_leaf_node_btree(db_file=""):
    script = [
        "insert 58 user58 person58@example.com",
        "insert 56 user56 person56@example.com",
        "insert 8 user8 person8@example.com",
        "insert 54 user54 person54@example.com",
        "insert 77 user77 person77@example.com",
        "insert 7 user7 person7@example.com",
        "insert 25 user25 person25@example.com",
        "insert 71 user71 person71@example.com",
        "insert 13 user13 person13@example.com",
        "insert 22 user22 person22@example.com",
        "insert 53 user53 person53@example.com",
        "insert 51 user51 person51@example.com",
        "insert 59 user59 person59@example.com",
        "insert 32 user32 person32@example.com",
        "insert 36 user36 person36@example.com",
        "insert 79 user79 person79@example.com",
        "insert 10 user10 person10@example.com",
        "insert 33 user33 person33@example.com",
        "insert 20 user20 person20@example.com",
        "insert 4 user4 person4@example.com",
        "insert 35 user35 person35@example.com",
        "insert 76 user76 person76@example.com",
        "insert 49 user49 person49@example.com",
        "insert 24 user24 person24@example.com",
        "insert 70 user70 person70@example.com",
        "insert 48 user48 person48@example.com",
        "insert 39 user39 person39@example.com",
        "insert 15 user15 person15@example.com",
        "insert 47 user47 person47@example.com",
        "insert 30 user30 person30@example.com",
        "insert 86 user86 person86@example.com",
        "insert 31 user31 person31@example.com",
        "insert 68 user68 person68@example.com",
        "insert 37 user37 person37@example.com",
        "insert 66 user66 person66@example.com",
        "insert 63 user63 person63@example.com",
        "insert 40 user40 person40@example.com",
        "insert 78 user78 person78@example.com",
        "insert 19 user19 person19@example.com",
        "insert 46 user46 person46@example.com",
        "insert 14 user14 person14@example.com",
        "insert 81 user81 person81@example.com",
        "insert 72 user72 person72@example.com",
        "insert 6 user6 person6@example.com",
        "insert 50 user50 person50@example.com",
        "insert 85 user85 person85@example.com",
        "insert 67 user67 person67@example.com",
        "insert 2 user2 person2@example.com",
        "insert 55 user55 person55@example.com",
        "insert 69 user69 person69@example.com",
        "insert 5 user5 person5@example.com",
        "insert 65 user65 person65@example.com",
        "insert 52 user52 person52@example.com",
        "insert 1 user1 person1@example.com",
        "insert 29 user29 person29@example.com",
        "insert 9 user9 person9@example.com",
        "insert 43 user43 person43@example.com",
        "insert 75 user75 person75@example.com",
        "insert 21 user21 person21@example.com",
        "insert 82 user82 person82@example.com",
        "insert 12 user12 person12@example.com",
        "insert 18 user18 person18@example.com",
        "insert 60 user60 person60@example.com",
        "insert 44 user44 person44@example.com",
        ".btree",
        ".exit",
    ]
    result = run_script(script,db_file=db_file)

    expected_output = [
        "db > Tree:",
        "- internal (size 1)",
        "  - internal (size 2)",
        "    - leaf (size 7)",
        "      - 1",
        "      - 2",
        "      - 4",
        "      - 5",
        "      - 6",
        "      - 7",
        "      - 8",
        "    - key 8",
        "    - leaf (size 11)",
        "      - 9",
        "      - 10",
        "      - 12",
        "      - 13",
        "      - 14",
        "      - 15",
        "      - 18",
        "      - 19",
        "      - 20",
        "      - 21",
        "      - 22",
        "    - key 22",
        "    - leaf (size 8)",
        "      - 24",
        "      - 25",
        "      - 29",
        "      - 30",
        "      - 31",
        "      - 32",
        "      - 33",
        "      - 35",
        "  - key 35",
        "  - internal (size 3)",
        "    - leaf (size 12)",
        "      - 36",
        "      - 37",
        "      - 39",
        "      - 40",
        "      - 43",
        "      - 44",
        "      - 46",
        "      - 47",
        "      - 48",
        "      - 49",
        "      - 50",
        "      - 51",
        "    - key 51",
        "    - leaf (size 11)",
        "      - 52",
        "      - 53",
        "      - 54",
        "      - 55",
        "      - 56",
        "      - 58",
        "      - 59",
        "      - 60",
        "      - 63",
        "      - 65",
        "      - 66",
        "    - key 66",
        "    - leaf (size 7)",
        "      - 67",
        "      - 68",
        "      - 69",
        "      - 70",
        "      - 71",
        "      - 72",
        "      - 75",
        "    - key 75",
        "    - leaf (size 8)",
        "      - 76",
        "      - 77",
        "      - 78",
        "      - 79",
        "      - 81",
        "      - 82",
        "      - 85",
        "      - 86",
        "db > ",
    ]

    #print(f"result: {result}")
    print(f"result[64:]: {result[64:]}")
    assert result[64:] == expected_output
    print(f"{sys._getframe().f_code.co_name} passed")


if __name__ == "__main__":
    # Usage: python <this script> <db file path>
    if len(sys.argv) < 2:
        print("need db file path")
        # A missing argument is a usage error: exit non-zero so shells/CI
        # don't mistake it for a passing run (the old exit(0) did).
        sys.exit(1)
    db_file = sys.argv[1]

    # Remove any leftover database so the test starts from an empty table;
    # the expected tree shape depends on inserting into a fresh file.
    if os.path.exists(db_file):
        print(f"{db_file} exists, remove it first")
        os.remove(db_file)

    test_prints_structure_of_7_leaf_node_btree(db_file)

    print("all tests passed.")

diff golang/13.go golang/14.go

6a7
>       "math"
85a87,88
> 
> /* 为了测试,保持较小 */
86a90,91
> 
> const INVALID_PAGE_NUM = math.MaxUint32
268a274,277
>       /*
>         由于根页码是0,因此在初始化内部节点时,如果不将其右子节点初始化为无效的页码,可能会导致右子节点为0,这将使该节点成为根节点的父节点。
>       */
>       *internalNodeRightChild(node) = INVALID_PAGE_NUM
317c326,331
<               return internalNodeRightChild(node)
---
>               rightChild := internalNodeRightChild(node)
>               if *rightChild == INVALID_PAGE_NUM {
>                       fmt.Printf("Tried to access right child of node, but was invalid page\n")
>                       os.Exit(1)
>               }
>               return rightChild
319c333,338
<       return internalNodeCell(node, childNum)
---
>       child := internalNodeCell(node, childNum)
>       if *child == INVALID_PAGE_NUM {
>               fmt.Printf("Tried to access child %d of node, but was invalid page\n", childNum)
>               os.Exit(1)
>       }
>       return child
354,358c373,374
< func getNodeMaxKey(node []byte) uint32 {
<       switch getNodeType(node) {
<       case NODE_INTERNAL:
<               return *internalNodeKey(node, *internalNodeNumKeys(node)-1)
<       case NODE_LEAF:
---
> func getNodeMaxKey(pager *Pager, node []byte) uint32 {
>       if getNodeType(node) == NODE_LEAF {
360,362c376,407
<       default:
<               // Handle other node types if needed
<               return 0 // or appropriate default value
---
>       }
>       rightChild := getPage(pager, *internalNodeRightChild(node))
>       return getNodeMaxKey(pager, rightChild)
> }
> 
> // 处理根节点的拆分。
> // 将旧根复制到新页,成为左子节点。
> // 重新初始化根页以包含新根节点。
> // 新根节点指向两个子节点。
> func createNewRoot(table *Table, rightChildPageNum uint32) {
>       root := getPage(table.pager, table.rootPageNum)
>       rightChild := getPage(table.pager, rightChildPageNum)
>       leftChildPageNum := getUnusedPageNum(table.pager)
>       leftChild := getPage(table.pager, leftChildPageNum)
> 
>       if getNodeType(root) == NODE_INTERNAL {
>               initializeInternalNode(rightChild)
>               initializeInternalNode(leftChild)
>       }
> 
>       // Left child has data copied from the old root
>       copy(leftChild, root)
>       setNodeRoot(leftChild, false)
> 
>       if getNodeType(leftChild) == NODE_INTERNAL {
>               var child []byte
>               for i := uint32(0); i < *internalNodeNumKeys(leftChild); i++ {
>                       child = getPage(table.pager, *internalNodeChild(leftChild, i))
>                       *nodeParent(child) = leftChildPageNum
>               }
>               child = getPage(table.pager, *internalNodeRightChild(leftChild))
>               *nodeParent(child) = leftChildPageNum
363a409,420
> 
>       // Root becomes a new internal node with one key and two children
>       initializeInternalNode(root)
>       setNodeRoot(root, true)
>       *internalNodeNumKeys(root) = 1
>       *internalNodeChild(root, 0) = leftChildPageNum
>       leftChildMaxKey := getNodeMaxKey(table.pager, leftChild)
> 
>       *internalNodeKey(root, 0) = leftChildMaxKey
>       *internalNodeRightChild(root) = rightChildPageNum
>       *nodeParent(leftChild) = table.rootPageNum
>       *nodeParent(rightChild) = table.rootPageNum
365a423
> // 向父节点添加一个新的子节点/键对,对应于子节点
369c427
<       childMaxKey := getNodeMaxKey(child)
---
>       childMaxKey := getNodeMaxKey(table.pager, child)
373,374d430
<       *internalNodeNumKeys(parent) = originalNumKeys + 1
< 
376,377c432,433
<               fmt.Println("Need to implement splitting internal node")
<               os.Exit(1)
---
>               internalNodeSplitAndInsert(table, parentPageNum, childPageNum)
>               return
380a437,442
>       // 具有右子节点为INVALID_PAGE_NUM的内部节点为空
>       if rightChildPageNum == INVALID_PAGE_NUM {
>               *internalNodeRightChild(parent) = childPageNum
>               return
>       }
> 
381a444,450
>       /*
>         如果我们已经达到节点的最大单元格数,就不能在分裂之前递增。
>         在没有插入新的键/子节点对的情况下递增,并立即调用
>         `internal_node_split_and_insert` 会导致在 `(max_cells + 1)`
>         处创建一个新的键,其值未初始化。
>       */
>       *internalNodeNumKeys(parent) = originalNumKeys + 1
383c452
<       if childMaxKey > getNodeMaxKey(rightChild) {
---
>       if childMaxKey > getNodeMaxKey(table.pager, rightChild) {
386c455
<               *internalNodeKey(parent, originalNumKeys) = getNodeMaxKey(rightChild)
---
>               *internalNodeKey(parent, originalNumKeys) = getNodeMaxKey(table.pager, rightChild)
391,392c460,461
<                       //destination := internalNodeCell(parent, i)
<                       //source := internalNodeCell(parent, i-1)
---
>                       destination := internalNodeCell(parent, i)
>                       source := internalNodeCell(parent, i-1)
394,395c463,464
<                       //copy(destination, source)
<                       *internalNodeCell(parent, i) = *internalNodeCell(parent, i-1)
---
>                       copy((*(*[INTERNAL_NODE_CELL_SIZE]byte)(unsafe.Pointer(destination)))[:], (*(*[INTERNAL_NODE_CELL_SIZE]byte)(unsafe.Pointer(source)))[:])
>                       //*internalNodeCell(parent, i) = *internalNodeCell(parent, i-1)
400a470,473
> func internalNodeSplitAndInsert(table *Table, parentPageNum, childPageNum uint32) {
>       oldPageNum := parentPageNum
>       oldNode := getPage(table.pager, parentPageNum)
>       oldMax := getNodeMaxKey(table.pager, oldNode)
401a475,543
>       child := getPage(table.pager, childPageNum)
>       childMax := getNodeMaxKey(table.pager, child)
> 
>       newPageNum := getUnusedPageNum(table.pager)
> 
>       // Flag to indicate if we are splitting the root node
>       // 这个简短的注释是chatGPT总结后加上的...
>       splittingRoot := isNodeRoot(oldNode)
> 
>       var parent, newNode []byte
>       if splittingRoot {
>               createNewRoot(table, newPageNum)
>               parent = getPage(table.pager, table.rootPageNum)
>               // If splitting root, update oldNode to point to the left child of the new root
>               oldPageNum = *internalNodeChild(parent, 0)
>               oldNode = getPage(table.pager, oldPageNum)
>       } else {
>               parent = getPage(table.pager, *nodeParent(oldNode))
>               newNode = getPage(table.pager, newPageNum)
>               initializeInternalNode(newNode)
>       }
> 
>       oldNumKeys := internalNodeNumKeys(oldNode)
> 
>       curPageNum := *internalNodeRightChild(oldNode)
>       cur := getPage(table.pager, curPageNum)
> 
>       // Move the right child into the new node and set the right child of old node to INVALID_PAGE_NUM
>       internalNodeInsert(table, newPageNum, curPageNum)
>       *nodeParent(cur) = newPageNum
>       *internalNodeRightChild(oldNode) = INVALID_PAGE_NUM
> 
>       // Move keys and child nodes to the new node until the middle key
>       for i := INTERNAL_NODE_MAX_CELLS - 1; i > INTERNAL_NODE_MAX_CELLS/2; i-- {
>               curPageNum = *internalNodeChild(oldNode, uint32(i))
>               cur = getPage(table.pager, curPageNum)
> 
>               internalNodeInsert(table, newPageNum, curPageNum)
>               *nodeParent(cur) = newPageNum
> 
>               (*oldNumKeys)--
>       }
> 
>       // Set the right child of old node to the highest key before the middle key and decrement the number of keys
>       *internalNodeRightChild(oldNode) = *internalNodeChild(oldNode, *oldNumKeys-1)
>       (*oldNumKeys)--
> 
>       // Determine which of the split nodes should contain the child to be inserted
>       maxAfterSplit := getNodeMaxKey(table.pager, oldNode)
>       destinationPageNum := newPageNum
> 
>       if childMax < maxAfterSplit {
>               destinationPageNum = oldPageNum
>       }
> 
>       // Insert the child node into the appropriate split node
>       internalNodeInsert(table, destinationPageNum, childPageNum)
>       *nodeParent(child) = destinationPageNum
> 
>       // Update the parent node's key to reflect the new highest key in the old node
>       updateInternalNodeKey(parent, oldMax, getNodeMaxKey(table.pager, oldNode))
> 
>       // If not splitting the root, insert the new node into its parent
>       if !splittingRoot {
>               internalNodeInsert(table, *nodeParent(oldNode), newPageNum)
>               *nodeParent(newNode) = *nodeParent(oldNode)
>       }
> }
> 
419,421c561,564
<               for i := uint32(0); i < numKeys; i++ {
<                       child = *internalNodeChild(node, i)
<                       printTree(pager, child, indentationLevel+1)
---
>               if numKeys > 0 {
>                       for i := uint32(0); i < numKeys; i++ {
>                               child = *internalNodeChild(node, i)
>                               printTree(pager, child, indentationLevel+1)
423,424c566,570
<                       indent(indentationLevel + 1)
<                       fmt.Printf("- key %d\n", *internalNodeKey(node, i))
---
>                               indent(indentationLevel + 1)
>                               fmt.Printf("- key %d\n", *internalNodeKey(node, i))
>                       }
>                       child = *internalNodeRightChild(node)
>                       printTree(pager, child, indentationLevel+1)
426,427d571
<               child = *internalNodeRightChild(node)
<               printTree(pager, child, indentationLevel+1)
629c773
< func readInput(reader *bufio.Reader, inputBuffer *InputBuffer) {
---
> func readInput(reader *bufio.Reader, table *Table, inputBuffer *InputBuffer) {
633a778,781
>               dbClose(table)
>               if err == io.EOF {
>                       os.Exit(0)
>               }
716,738d863
< func createNewRoot(table *Table, rightChildPageNum uint32) {
<       root := getPage(table.pager, table.rootPageNum)
<       rightChild := getPage(table.pager, rightChildPageNum)
<       leftChildPageNum := getUnusedPageNum(table.pager)
<       leftChild := getPage(table.pager, leftChildPageNum)
< 
<       // Left child gets data copied from the old root
<       copy(leftChild, root[:])
<       setNodeRoot(leftChild, false)
< 
<       // Root becomes a new internal node with one key and two children
<       initializeInternalNode(root)
<       setNodeRoot(root, true)
<       *internalNodeNumKeys(root) = 1
<       *internalNodeChild(root, 0) = leftChildPageNum
< 
<       leftChildMaxKey := getNodeMaxKey(leftChild)
<       *internalNodeKey(root, 0) = leftChildMaxKey
<       *internalNodeRightChild(root) = rightChildPageNum
<       *nodeParent(leftChild) = table.rootPageNum
<       *nodeParent(rightChild) = table.rootPageNum
< }
< 
744c869
<       oldMax := getNodeMaxKey(oldNode)
---
>       oldMax := getNodeMaxKey(cursor.table.pager, oldNode)
784c909
<               newMax := getNodeMaxKey(oldNode)
---
>               newMax := getNodeMaxKey(cursor.table.pager, oldNode)
843a969
>       i := 0
847a974
>               i++
848a976
>       fmt.Printf("total_rows: %d\n", i)
876c1004,1007
<               readInput(reader, inputBuffer)
---
>               readInput(reader, table, inputBuffer)
>               if inputBuffer.inputLength == 0 {
>                       continue
>               }