修改为上证指数数据,本地验证代码可运行,其中 gpt2 模型需要自行科学上网下载,参考代码如下:
import yfinance as yf
import numpy as np
import matplotlib.pyplot as plt
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer
from adata import stock
import pandas as pd
from transformers import AutoTokenizer, AutoModel
# Fetch the Shanghai index daily K-line (index code 000001) starting from
# 2024-01-01 and echo the raw table for inspection.
index_query = {"index_code": "000001", "k_type": 1, "start_date": "2024-01-01"}
sh_df = stock.market.get_market_index(**index_query)
print(sh_df)

# Closing prices as a plain Python list; this is the series the model is
# trained on and the line that gets plotted later.
prices = sh_df.loc[:, "close"].tolist()
print(prices)
# Load the pretrained GPT-2 tokenizer and language-model head.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

# Keep the embedding matrix in sync with the tokenizer vocabulary. This must
# happen BEFORE the optimizer is created so the optimizer holds references to
# the final parameter tensors.
model.resize_token_embeddings(len(tokenizer))

# Serialize the price history as one space-separated string and tokenize it
# into a (1, seq_len) tensor of token ids.
encoded_prices = tokenizer.encode(
    " ".join(str(price) for price in prices),
    return_tensors="pt",
)

# GPT-2 has a fixed context window (model.config.n_positions). A long price
# history easily exceeds it and would fail inside the position embedding, so
# keep only the most recent tokens and leave headroom for the 10 tokens that
# are generated after training.
generation_headroom = 10
max_prompt_tokens = model.config.n_positions - generation_headroom
encoded_prices = encoded_prices[:, -max_prompt_tokens:]

# Brief fine-tuning: three full-sequence gradient steps where the labels are
# the inputs themselves (standard causal language-model objective).
model.train()
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)
for _ in range(3):
    model.zero_grad()
    outputs = model(encoded_prices, labels=encoded_prices)
    loss = outputs.loss
    loss.backward()
    optimizer.step()
# How many tokens to generate and how many future trading days to plot.
num_new_tokens = 10
num_forecast_days = 3

# Make sure prompt + continuation fits into GPT-2's context window; slicing
# with a negative start is a no-op when the prompt is already short enough.
prompt_ids = encoded_prices[:, -(model.config.n_positions - num_new_tokens):]

# Generate in eval mode (dropout off) and without gradient tracking.
model.eval()
with torch.no_grad():
    generated = model.generate(
        prompt_ids,
        attention_mask=torch.ones_like(prompt_ids),
        max_length=prompt_ids.shape[1] + num_new_tokens,
        temperature=1.0,
        num_return_sequences=1,
        # GPT-2 defines no pad token; reuse EOS to avoid the runtime warning.
        pad_token_id=tokenizer.eos_token_id,
    )
print(generated)

# Full decoded sequence (prompt + continuation), kept for inspection.
generated_prices = tokenizer.decode(generated[0], skip_special_tokens=True).split()
print(generated_prices)

# Only the tokens after the prompt are predictions. Decoding the whole
# sequence and plotting it against "history + 3" dates produces mismatched
# x/y lengths, so the continuation is isolated here.
continuation_text = tokenizer.decode(
    generated[0][prompt_ids.shape[1]:],
    skip_special_tokens=True,
)
continuation_pieces = continuation_text.split()
# If the continuation does not start with whitespace, its first piece is just
# more digits glued onto the last historical price, not a new price.
if continuation_text and not continuation_text[0].isspace():
    continuation_pieces = continuation_pieces[1:]

# The model emits free text, so keep only the pieces that parse as numbers.
# NOTE: the final value may be cut short by the token budget.
predicted_prices = []
for piece in continuation_pieces:
    try:
        predicted_prices.append(float(piece))
    except ValueError:
        continue
predicted_prices = predicted_prices[:num_forecast_days]

# Future x-axis labels: the next business days after the last historical
# date, formatted like the existing 'trade_date' strings (assumed to be
# YYYY-MM-DD). Weekends are skipped; exchange holidays are not.
last_date = sh_df['trade_date'].iloc[-1]
first_forecast_day = pd.to_datetime(last_date) + pd.offsets.BDay(1)
new_dates = list(
    pd.bdate_range(start=first_forecast_day, periods=num_forecast_days).strftime('%Y-%m-%d')
)
forecast_dates = new_dates[:len(predicted_prices)]
print(list(zip(forecast_dates, predicted_prices)))

plt.figure(figsize=(12, 6))
plt.plot(sh_df["trade_date"], prices, label="Historical Prices")
# Plot predicted prices (skipped when the model produced no parseable number).
if predicted_prices:
    plt.plot(forecast_dates, predicted_prices, "g^", label="Predicted Prices")
plt.xlabel("Date")
plt.ylabel("Stock Price")
plt.title("上证 - Historical and Predicted Stock Prices (GPT)")
plt.legend()
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()
运行结果如下:
